Diffstat (limited to 'core/src/main/scala/org')
-rw-r--r--  core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala  17
-rw-r--r--  core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala  33
-rw-r--r--  core/src/main/scala/org/apache/spark/Aggregator.scala  49
-rw-r--r--  core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala  23
-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala  12
-rw-r--r--  core/src/main/scala/org/apache/spark/FutureAction.scala  250
-rw-r--r--  core/src/main/scala/org/apache/spark/InterruptibleIterator.scala  30
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala  169
-rw-r--r--  core/src/main/scala/org/apache/spark/ShuffleFetcher.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala  229
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala  25
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala  21
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala  21
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala  24
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala  35
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala  19
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java  10
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java  3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/Function.java  2
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/Function3.java  36
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java  2
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java  5
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala  34
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala  1058
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala  13
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala  410
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala  54
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala  247
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala  603
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala  29
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala  34
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala  420
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala  7
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala  49
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/Client.scala  80
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala  7
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala  53
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala  7
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala  90
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala  45
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala  232
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala  46
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala  53
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala  26
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala  203
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala  42
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala  136
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala  85
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala  15
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala  181
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala  8
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala (renamed from core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala)  38
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala  167
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala  18
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManager.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala  22
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala  27
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala  9
-rw-r--r--  core/src/main/scala/org/apache/spark/package.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala  123
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala  6
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala  9
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala  26
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala  123
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala (renamed from core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala)  12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala  79
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala  16
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala  106
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  200
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala  29
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala  7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala  676
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala  62
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Pool.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala  48
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala  76
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala  21
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala  6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala  13
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala  63
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala  20
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala  10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala  44
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala  55
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala  106
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala)  26
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala)  64
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala  6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala  66
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala  20
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala  26
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala  15
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala  196
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala  24
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala  10
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockException.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala  24
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockId.scala  103
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockInfo.scala  81
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala  628
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala  8
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala  25
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala  16
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala  1
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessage.scala  38
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala  7
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala  142
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockStore.scala  14
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala  151
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskStore.scala  280
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/FileSegment.scala  28
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/MemoryStore.scala  34
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala  200
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala  86
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StorageUtils.scala  47
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala  6
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala  105
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala  8
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala  21
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala  26
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala  23
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala  230
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala  37
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala  45
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/BitSet.scala  103
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala  154
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala  272
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala  128
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala  53
150 files changed, 6611 insertions, 4541 deletions
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
index f87460039b..0c47afae54 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
@@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
+private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
- val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
+ "org.apache.hadoop.mapred.JobContext")
+ val ctor = klass.getDeclaredConstructor(classOf[JobConf],
+ classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
+ "org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
index 93180307fa..32429f01ac 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
-import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
+import org.apache.hadoop.conf.Configuration
+private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
+ val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
- // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
+ // First, attempt to use the old-style constructor that takes a boolean isMap
+ // (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
- classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ classOf[Int], classOf[Int])
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
- // failed, look for the new ctor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
- val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
+ // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
+ val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ .asInstanceOf[Class[Enum[_]]]
+ val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
+ taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
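Both utility traits above paper over the Hadoop 1.x/2.x API split by resolving the context classes reflectively. The firstAvailableClass helper they call is not part of the hunks shown here, so the following is only a sketch of the assumed pattern: try the Hadoop 2.x class name first and fall back to the 1.x name when it is not on the classpath.

private def firstAvailableClass(first: String, second: String): Class[_] = {
  try {
    Class.forName(first)     // e.g. "org.apache.hadoop.mapred.JobContextImpl" (Hadoop 2.x)
  } catch {
    case _: ClassNotFoundException =>
      Class.forName(second)  // e.g. "org.apache.hadoop.mapred.JobContext" (Hadoop 1.x)
  }
}

The reflective constructor lookups in newJobContext and newTaskAttemptContext then run against whichever class was found.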
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 3ef402926e..1a2ec55876 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,43 +17,42 @@
package org.apache.spark
-import java.util.{HashMap => JHashMap}
+import org.apache.spark.util.AppendOnlyMap
-import scala.collection.JavaConversions._
-
-/** A set of functions used to aggregate data.
- *
- * @param createCombiner function to create the initial value of the aggregation.
- * @param mergeValue function to merge a new value into the aggregation result.
- * @param mergeCombiners function to merge outputs from multiple mergeValue function.
- */
+/**
+ * A set of functions used to aggregate data.
+ *
+ * @param createCombiner function to create the initial value of the aggregation.
+ * @param mergeValue function to merge a new value into the aggregation result.
+ * @param mergeCombiners function to merge outputs from multiple mergeValue function.
+ */
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- for (kv <- iter) {
- val oldC = combiners.get(kv._1)
- if (oldC == null) {
- combiners.put(kv._1, createCombiner(kv._2))
- } else {
- combiners.put(kv._1, mergeValue(oldC, kv._2))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (iter.hasNext) {
+ kv = iter.next()
+ combiners.changeValue(kv._1, update)
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- iter.foreach { case(k, c) =>
- val oldC = combiners.get(k)
- if (oldC == null) {
- combiners.put(k, c)
- } else {
- combiners.put(k, mergeCombiners(oldC, c))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ var kc: (K, C) = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ }
+ while (iter.hasNext) {
+ kc = iter.next()
+ combiners.changeValue(kc._1, update)
}
combiners.iterator
}
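The rewrite above funnels every record through AppendOnlyMap.changeValue with one reused update closure instead of doing a get/put pair on a java.util.HashMap per key. A minimal usage sketch of the Aggregator API as it appears in this diff (the per-key sum functions and sample data are illustrative only):

val agg = Aggregator[String, Int, Int](
  createCombiner = (v: Int) => v,                  // first value seen for a key
  mergeValue = (c: Int, v: Int) => c + v,          // fold another value into the running combiner
  mergeCombiners = (c1: Int, c2: Int) => c1 + c2)  // merge combiners from different map outputs

val perKey = agg.combineValuesByKey(Iterator(("a", 1), ("b", 2), ("a", 3)))
perKey.foreach(println)  // ("a", 4) and ("b", 2), produced in a single AppendOnlyMap pass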
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 908ff56a6b..d9ed572da6 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -22,13 +22,17 @@ import scala.collection.mutable.HashMap
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+ override def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
+ serializer: Serializer)
: Iterator[T] =
{
@@ -45,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
- (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+ (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
@@ -58,9 +62,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
block.asInstanceOf[Iterator[T]]
}
case None => {
- val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
- case regex(shufId, mapId, _) =>
+ case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
@@ -74,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
- CompletionIterator[T, Iterator[T]](itr, {
+ val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
- metrics.shuffleReadMetrics = Some(shuffleMetrics)
+ context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
})
+
+ new InterruptibleIterator[T](context, completionIter)
}
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 4cf7eb96da..519ecde50a 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -18,7 +18,7 @@
package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
import org.apache.spark.rdd.RDD
@@ -28,17 +28,17 @@ import org.apache.spark.rdd.RDD
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
/** Keys of RDD splits that are being computed/loaded. */
- private val loading = new HashSet[String]()
+ private val loading = new HashSet[RDDBlockId]()
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
+ val key = RDDBlockId(rdd.id, split.index)
logDebug("Looking for partition " + key)
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
loading.add(key)
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (context.runningLocally) { return computedValues }
val elements = new ArrayBuffer[Any]
elements ++= computedValues
- blockManager.put(key, elements, storageLevel, true)
+ blockManager.put(key, elements, storageLevel, tellMaster = true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
new file mode 100644
index 0000000000..1ad9240cfa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobFailed
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * A future for the result of an action. This is an extension of the Scala Future interface to
+ * support cancellation.
+ */
+trait FutureAction[T] extends Future[T] {
+ // Note that we redefine methods of the Future trait here explicitly so we can specify a different
+ // documentation (with reference to the word "action").
+
+ /**
+ * Cancels the execution of this action.
+ */
+ def cancel()
+
+ /**
+ * Blocks until this action completes.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @return this FutureAction
+ */
+ override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
+
+ /**
+ * Awaits and returns the result (of type T) of this action.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @throws Exception exception during action execution
+ * @return the result value if the action is completed within the specific maximum wait time
+ */
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T
+
+ /**
+ * When this action is completed, either through an exception, or a value, applies the provided
+ * function.
+ */
+ def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+
+ /**
+ * Returns whether the action has already been completed with a value or an exception.
+ */
+ override def isCompleted: Boolean
+
+ /**
+ * The value of this Future.
+ *
+ * If the future is not completed the returned value will be None. If the future is completed
+ * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
+ * it contains an exception.
+ */
+ override def value: Option[Try[T]]
+
+ /**
+ * Blocks and returns the result of this job.
+ */
+ @throws(classOf[Exception])
+ def get(): T = Await.result(this, Duration.Inf)
+}
+
+
+/**
+ * The future holding the result of an action that triggers a single job. Examples include
+ * count, collect, reduce.
+ */
+class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+ extends FutureAction[T] {
+
+ override def cancel() {
+ jobWaiter.cancel()
+ }
+
+ override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
+ if (!atMost.isFinite()) {
+ awaitResult()
+ } else {
+ val finishTime = System.currentTimeMillis() + atMost.toMillis
+ while (!isCompleted) {
+ val time = System.currentTimeMillis()
+ if (time >= finishTime) {
+ throw new TimeoutException
+ } else {
+ jobWaiter.wait(finishTime - time)
+ }
+ }
+ }
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ ready(atMost)(permit)
+ awaitResult() match {
+ case scala.util.Success(res) => res
+ case scala.util.Failure(e) => throw e
+ }
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+ executor.execute(new Runnable {
+ override def run() {
+ func(awaitResult())
+ }
+ })
+ }
+
+ override def isCompleted: Boolean = jobWaiter.jobFinished
+
+ override def value: Option[Try[T]] = {
+ if (jobWaiter.jobFinished) {
+ Some(awaitResult())
+ } else {
+ None
+ }
+ }
+
+ private def awaitResult(): Try[T] = {
+ jobWaiter.awaitResult() match {
+ case JobSucceeded => scala.util.Success(resultFunc)
+ case JobFailed(e: Exception, _) => scala.util.Failure(e)
+ }
+ }
+}
+
+
+/**
+ * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
+ * action thread if it is being blocked by a job.
+ */
+class ComplexFutureAction[T] extends FutureAction[T] {
+
+ // Pointer to the thread that is executing the action. It is set when the action is run.
+ @volatile private var thread: Thread = _
+
+ // A flag indicating whether the future has been cancelled. This is used in case the future
+ // is cancelled before the action was even run (and thus we have no thread to interrupt).
+ @volatile private var _cancelled: Boolean = false
+
+ // A promise used to signal the future.
+ private val p = promise[T]()
+
+ override def cancel(): Unit = this.synchronized {
+ _cancelled = true
+ if (thread != null) {
+ thread.interrupt()
+ }
+ }
+
+ /**
+ * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+ * should use runJob implementation in this promise. See takeAsync for example.
+ */
+ def run(func: => T)(implicit executor: ExecutionContext): this.type = {
+ scala.concurrent.future {
+ thread = Thread.currentThread
+ try {
+ p.success(func)
+ } catch {
+ case e: Exception => p.failure(e)
+ } finally {
+ thread = null
+ }
+ }
+ this
+ }
+
+ /**
+ * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
+ * to enable cancellation.
+ */
+ def runJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R) {
+ // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+ // command need to be in an atomic block.
+ val job = this.synchronized {
+ if (!cancelled) {
+ rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+ } else {
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ // Wait for the job to complete. If the action is cancelled (with an interrupt),
+ // cancel the job and stop the execution. This is not in a synchronized block because
+ // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
+ try {
+ Await.ready(job, Duration.Inf)
+ } catch {
+ case e: InterruptedException =>
+ job.cancel()
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ /**
+ * Returns whether the promise has been cancelled.
+ */
+ def cancelled: Boolean = _cancelled
+
+ @throws(classOf[InterruptedException])
+ @throws(classOf[scala.concurrent.TimeoutException])
+ override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+ p.future.ready(atMost)(permit)
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ p.future.result(atMost)(permit)
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+ p.future.onComplete(func)(executor)
+ }
+
+ override def isCompleted: Boolean = p.isCompleted
+
+ override def value: Option[Try[T]] = p.future.value
+}
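FutureAction is a Scala Future extended with cancel(); the SimpleFutureAction variant is what SparkContext.submitJob (added later in this diff) hands back. A hedged usage sketch, assuming an existing SparkContext named sc and purely illustrative partition work:

import scala.concurrent.ExecutionContext.Implicits.global

val rdd = sc.parallelize(1 to 1000, 4)
val future = sc.submitJob(
  rdd,
  (iter: Iterator[Int]) => { Thread.sleep(1000); iter.sum },               // processPartition
  Seq(0, 1, 2, 3),                                                         // partitions to run
  (index: Int, sum: Int) => println("partition " + index + " -> " + sum),  // resultHandler
  ())                                                                      // resultFunc (overall result)

future.onComplete {
  case scala.util.Success(_) => println("all partitions done")
  case scala.util.Failure(e) => println("job failed: " + e)
}
// Unlike a plain Future, the running job can be cancelled:
// future.cancel()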
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
new file mode 100644
index 0000000000..56e0b8d2c0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * An iterator that wraps around an existing iterator to provide task killing functionality.
+ * It works by checking the interrupted flag in TaskContext.
+ */
+class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+ extends Iterator[T] {
+
+ def hasNext: Boolean = !context.interrupted && delegate.hasNext
+
+ def next(): T = delegate.next()
+}
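The wrapper only intercepts hasNext, so a killed task stops pulling records at the next element boundary rather than mid-element. A standalone sketch of the same pattern, with a plain volatile flag standing in for TaskContext.interrupted (everything here is illustrative, not Spark API):

object InterruptibleSketch {
  @volatile var interrupted = false   // stand-in for context.interrupted

  def interruptible[T](underlying: Iterator[T]): Iterator[T] = new Iterator[T] {
    def hasNext: Boolean = !interrupted && underlying.hasNext  // stop once the task is marked killed
    def next(): T = underlying.next()
  }

  def main(args: Array[String]) {
    val it = interruptible(Iterator.from(1))   // an "infinite" source
    println(it.take(3).toList)                 // List(1, 2, 3)
    interrupted = true
    println(it.hasNext)                        // false: consumers drain no further elements
  }
}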
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1afb1870f1..035942ad39 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,7 +20,6 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
@@ -34,7 +33,7 @@ import scala.concurrent.duration._
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -42,11 +41,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+ extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
- sender ! tracker.getSerializedLocations(shuffleId)
+ sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -62,22 +62,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
- private var epoch: Long = 0
- private val epochLock = new java.lang.Object
+ protected var epoch: Long = 0
+ protected val epochLock = new java.lang.Object
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
- var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
-
- val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
+ private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
@@ -88,50 +85,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
+ private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
- throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
- }
- }
-
- def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses(shuffleId)
- array.synchronized {
- array(mapId) = status
- }
- }
-
- def registerMapOutputs(
- shuffleId: Int,
- statuses: Array[MapStatus],
- changeEpoch: Boolean = false) {
- mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
- if (changeEpoch) {
- incrementEpoch()
- }
- }
-
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var arrayOpt = mapStatuses.get(shuffleId)
- if (arrayOpt.isDefined && arrayOpt.get != null) {
- var array = arrayOpt.get
- array.synchronized {
- if (array(mapId) != null && array(mapId).location == bmAddress) {
- array(mapId) = null
- }
- }
- incrementEpoch()
- } else {
- throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
- }
- }
-
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@@ -170,7 +129,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
+ fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@@ -196,9 +155,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
- private def cleanup(cleanupTime: Long) {
+ protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@@ -208,15 +166,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
- // Called on master to increment the epoch number
- def incrementEpoch() {
- epochLock.synchronized {
- epoch += 1
- logDebug("Increasing epoch to " + epoch)
- }
- }
-
- // Called on master or workers to get current epoch number
+ // Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@@ -230,14 +180,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
- // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
- mapStatuses.clear()
epoch = newEpoch
+ mapStatuses.clear()
+ }
+ }
+ }
+}
+
+private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ private var cacheEpoch = epoch
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ def registerShuffle(shuffleId: Int, numMaps: Int) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+ }
+ }
+
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ val array = mapStatuses(shuffleId)
+ array.synchronized {
+ array(mapId) = status
+ }
+ }
+
+ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+ if (changeEpoch) {
+ incrementEpoch()
+ }
+ }
+
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
+ array.synchronized {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
+ array(mapId) = null
+ }
}
+ incrementEpoch()
+ } else {
+ throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
- def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
+ }
+ }
+
+ def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@@ -255,7 +253,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeStatuses(statuses)
+ val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@@ -263,13 +261,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
- return bytes
+ bytes
+ }
+
+ protected override def cleanup(cleanupTime: Long) {
+ super.cleanup(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
+ override def stop() {
+ super.stop()
+ cachedSerializedStatuses.clear()
+ }
+
+ override def updateEpoch(newEpoch: Long) {
+ // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@@ -280,18 +296,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
- // Opposite of serializeStatuses.
- def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ // Opposite of serializeMapStatuses.
+ def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().
- // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
- // comment this out - nulls could be due to missing location ?
- asInstanceOf[Array[MapStatus]] // .filter( _ != null )
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
}
-}
-
-private[spark] object MapOutputTracker {
- private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
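The serialization helpers moved into the new MapOutputTracker companion object keep the GZIP-over-Java-serialization approach: the payload compresses well because many map outputs repeat the same hostnames. A standalone sketch of that round trip, with plain strings standing in for MapStatus objects (illustrative only):

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

def serialize(statuses: Array[String]): Array[Byte] = {
  val out = new ByteArrayOutputStream
  val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
  objOut.writeObject(statuses)
  objOut.close()            // finishes the GZIP stream so the byte array is complete
  out.toByteArray
}

def deserialize(bytes: Array[Byte]): Array[String] = {
  val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
  objIn.readObject().asInstanceOf[Array[String]]
}

val statuses = Array.fill(1000)("blockmanager@host-42.example.com:7077")
val bytes = serialize(statuses)
println(bytes.length + " compressed bytes for 1000 repeated locations")
assert(deserialize(bytes).toSeq == statuses.toSeq)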
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index 307c383a89..a85aa50a9b 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
/** Stop the fetcher */
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1b003cc685..cc44a4c7dd 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.reflect.{ ClassTag, classTag}
@@ -53,21 +53,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.LocalSparkCluster
+import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
- ClusterScheduler}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+ SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
+import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
+import org.apache.spark.ui.SparkUI
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
+ TimeStampedHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -121,9 +119,9 @@ class SparkContext(
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
- private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
+ private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
- // Initalize the Spark UI
+ // Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
@@ -149,6 +147,14 @@ class SparkContext(
executorEnvs ++= environment
}
+ // Set SPARK_USER for user who is running SparkContext.
+ val sparkUser = Option {
+ Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+ }.getOrElse {
+ SparkContext.SPARK_UNKNOWN_USER
+ }
+ executorEnvs("SPARK_USER") = sparkUser
+
// Create and start the scheduler
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -158,9 +164,11 @@ class SparkContext(
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
- val SPARK_REGEX = """(spark://.*)""".r
- //Regular expression for connection to Mesos cluster
- val MESOS_REGEX = """(mesos://.*)""".r
+ val SPARK_REGEX = """spark://(.*)""".r
+ // Regular expression for connection to Mesos cluster
+ val MESOS_REGEX = """mesos://(.*)""".r
+ // Regular expression for connection to Simr cluster
+ val SIMR_REGEX = """simr://(.*)""".r
master match {
case "local" =>
@@ -174,7 +182,14 @@ class SparkContext(
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ val masterUrls = sparkUrl.split(",").map("spark://" + _)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
+ scheduler.initialize(backend)
+ scheduler
+
+ case SIMR_REGEX(simrUrl) =>
+ val scheduler = new ClusterScheduler(this)
+ val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
scheduler.initialize(backend)
scheduler
@@ -190,8 +205,8 @@ class SparkContext(
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
- val sparkUrl = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ val masterUrls = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
@@ -210,25 +225,24 @@ class SparkContext(
throw new SparkException("YARN mode not available ?", th)
}
}
- val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
- case _ =>
- if (MESOS_REGEX.findFirstIn(master).isEmpty) {
- logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
- }
+ case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
} else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
}
scheduler.initialize(backend)
scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
}
}
taskScheduler.start()
@@ -241,7 +255,7 @@ class SparkContext(
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val env = SparkEnv.get
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
@@ -251,8 +265,10 @@ class SparkContext(
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
- conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ Utils.getSystemProperties.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), value)
+ }
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
@@ -285,15 +301,46 @@ class SparkContext(
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
+ @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+ setJobGroup("", value)
+ }
+
+ /**
+ * Assigns a group id to all the jobs started by this thread until the group id is set to a
+ * different value or cleared.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark actions or jobs.
+ * Application programmers can use this method to group all those jobs together and give a
+ * group description. Once set, the Spark web UI will associate such jobs with this group.
+ *
+ * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
+ * running jobs in this group. For example,
+ * {{{
+ * // In the main thread:
+ * sc.setJobGroup("some_job_to_cancel", "some job description")
+ * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ *
+ * // In a separate thread:
+ * sc.cancelJobGroup("some_job_to_cancel")
+ * }}}
+ */
+ def setJobGroup(groupId: String, description: String) {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ }
+
+ /** Clear the job group id and its description. */
+ def clearJobGroup() {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
}
// Post init
taskScheduler.postStartHook()
- val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
- val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+ private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
+ private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -332,7 +379,7 @@ class SparkContext(
}
/**
- * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
@@ -344,7 +391,7 @@ class SparkContext(
minSplits: Int = defaultMinSplits
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
- SparkEnv.get.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -358,24 +405,15 @@ class SparkContext(
): RDD[(K, V)] = {
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
- hadoopFile(path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
- }
-
- /**
- * Get an RDD for a Hadoop file with an arbitray InputFormat. Accept a Hadoop Configuration
- * that has already been broadcast, assuming that it's safe to use it to construct a
- * HadoopFileRDD (i.e., except for file 'path', all other configuration properties can be resued).
- */
- def hadoopFile[K, V](
- path: String,
- confBroadcast: Broadcast[SerializableWritable[Configuration]],
- inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K],
- valueClass: Class[V],
- minSplits: Int
- ): RDD[(K, V)] = {
- new HadoopFileRDD(
- this, path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
+ val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
+ new HadoopRDD(
+ this,
+ confBroadcast,
+ Some(setInputPathsFunc),
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ minSplits)
}
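
As a hedged sketch of the overload rewritten above (path and input classes are illustrative), the call site is unchanged; the broadcast Configuration and the per-task setInputPathsFunc remain internal details of HadoopRDD.

  import org.apache.hadoop.io.{LongWritable, Text}
  import org.apache.hadoop.mapred.TextInputFormat

  val lines = sc.hadoopFile("hdfs://namenode/logs",               // illustrative input path
    classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
  lines.map(_._2.toString).take(5)                                // materialize a few records
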
/**
@@ -563,7 +601,8 @@ class SparkContext(
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
- case _ => path
+ case "local" => "file:" + uri.getPath
+ case _ => path
}
addedFiles(key) = System.currentTimeMillis
@@ -657,12 +696,11 @@ class SparkContext(
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
if (path == null) {
- logWarning("null specified as parameter to addJar",
- new SparkException("null specified as parameter to addJar"))
+ logWarning("null specified as parameter to addJar")
} else {
var key = ""
if (path.contains("\\")) {
@@ -671,12 +709,27 @@ class SparkContext(
} else {
val uri = new URI(path)
key = uri.getScheme match {
+ // A JAR file which exists only on the driver node
case null | "file" =>
- if (env.hadoop.isYarnMode()) {
- logWarning("local jar specified as parameter to addJar under Yarn mode")
- return
+ if (SparkHadoopUtil.get.isYarnMode()) {
+            // In order for this to work on YARN, the user must specify the --addJars option to
+            // the client to upload the file into the distributed cache so that it shows up in
+            // the current working directory.
+ val fileName = new Path(uri.getPath).getName()
+ try {
+ env.httpFileServer.addJar(new File(fileName))
+ } catch {
+ case e: Exception => {
+ logError("Error adding jar (" + e + "), was the --addJars option used?")
+ throw e
+ }
+ }
+ } else {
+ env.httpFileServer.addJar(new File(uri.getPath))
}
- env.httpFileServer.addJar(new File(uri.getPath))
+ // A JAR file which exists locally on every worker node
+ case "local" =>
+ "file:" + uri.getPath
case _ =>
path
}
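
A short sketch of the schemes handled above (jar paths are illustrative): a local: URI is assumed to already exist at the same path on every worker node, while other URIs are served or fetched for the executors.

  sc.addJar("local:/opt/libs/udfs.jar")          // assumed present on every worker; nothing is uploaded
  sc.addJar("hdfs://namenode/libs/udfs.jar")     // fetched from HDFS by each executor
  sc.addJar("http://repo.example.com/udfs.jar")  // fetched over HTTP
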
@@ -750,13 +803,13 @@ class SparkContext(
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.formatSparkCallSite
+ val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
- localProperties.get)
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
- result
}
/**
@@ -843,6 +896,42 @@ class SparkContext(
}
/**
+   * Submit a job for execution and return a FutureAction holding the result.
+ */
+ def submitJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R): SimpleFutureAction[R] =
+ {
+ val cleanF = clean(processPartition)
+ val callSite = Utils.formatSparkCallSite
+ val waiter = dagScheduler.submitJob(
+ rdd,
+ (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
+ partitions,
+ callSite,
+ allowLocal = false,
+ resultHandler,
+ localProperties.get)
+ new SimpleFutureAction(waiter, resultFunc)
+ }
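
A hedged sketch of calling the new submitJob API directly (the partition-size job is illustrative); resultFunc is only evaluated once all requested partitions have completed.

  val data = sc.parallelize(1 to 1000, 4)
  val sizes = new Array[Int](4)
  val future = sc.submitJob(
    data,
    (iter: Iterator[Int]) => iter.size,                // processPartition
    0 until 4,                                         // partitions to run on
    (index: Int, size: Int) => sizes(index) = size,    // resultHandler, called per partition
    sizes.sum)                                         // resultFunc, computed when the job finishes
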
+
+ /**
+ * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
+ * for more information.
+ */
+ def cancelJobGroup(groupId: String) {
+ dagScheduler.cancelJobGroup(groupId)
+ }
+
+ /** Cancel all jobs that have been scheduled or are running. */
+ def cancelAllJobs() {
+ dagScheduler.cancelAllJobs()
+ }
+
+ /**
   * Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*/
@@ -859,9 +948,8 @@ class SparkContext(
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
- val env = SparkEnv.get
val path = new Path(dir)
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -898,7 +986,12 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
- val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+
+ private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
@@ -925,6 +1018,8 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
+ implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index a267407c67..84750e2e85 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -25,13 +25,13 @@ import akka.remote.RemoteActorRefProvider
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
import org.apache.spark.network.ConnectionManager
import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.api.python.PythonWorkerFactory
+import com.google.common.collect.MapMaker
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -58,18 +58,9 @@ class SparkEnv (
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
- val hadoop = {
- val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
- if(yarnMode) {
- try {
- Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
- } catch {
- case th: Throwable => throw new SparkException("Unable to load YARN support", th)
- }
- } else {
- new SparkHadoopUtil
- }
- }
+ // A general, soft-reference map for metadata needed during HadoopRDD split computation
+ // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
+ private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
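
A minimal sketch of the soft-value cache pattern introduced above, using the same Guava MapMaker call (the key and the recompute step are illustrative); soft-referenced values can be collected under memory pressure, so callers must be able to rebuild an entry on a miss.

  import com.google.common.collect.MapMaker

  val cache = new MapMaker().softValues().makeMap[String, Any]()

  def getOrCompute(key: String)(compute: => Any): Any = {
    val cached = cache.get(key)
    if (cached != null) cached
    else {
      val value = compute            // e.g. rebuild a JobConf from the broadcast Configuration
      cache.put(key, value)
      value
    }
  }
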
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
@@ -188,10 +179,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
- val mapOutputTracker = new MapOutputTracker()
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster()
+ } else {
+ new MapOutputTracker()
+ }
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
- new MapOutputTrackerActor(mapOutputTracker))
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2bab9d6e3d..103a1c2051 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
+import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
-import java.io.IOException
import java.util.Date
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
-class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
+private[apache]
+class SparkHadoopWriter(@transient jobConf: JobConf)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
getOutputCommitter().setupTask(getTaskContext())
- writer = getOutputFormat().getRecordWriter(
- fs, conf.value, outputName, Reporter.NULL)
+ writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
- if (writer!=null) {
- //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
+ if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
}
+private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index c2c358c7ad..cae983ed4c 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -17,21 +17,30 @@
package org.apache.spark
-import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.executor.TaskMetrics
+
class TaskContext(
val stageId: Int,
- val splitId: Int,
+ val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
- val taskMetrics: TaskMetrics = TaskMetrics.empty()
+ @volatile var interrupted: Boolean = false,
+ private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
- @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @deprecated("use partitionId", "0.8.1")
+ def splitId = partitionId
+
+ // List of callback functions to execute when the task completes.
+ @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
- // Add a callback function to be executed on task completion. An example use
- // is for HadoopRDD to register a callback to close the input stream.
+ /**
+ * Add a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * @param f Callback function.
+ */
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}
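
A hedged sketch of the callback API documented above (the file-reading task is illustrative): a resource opened inside a task is registered for cleanup when the task completes, the same pattern HadoopRDD uses for its input stream.

  import scala.io.Source
  import org.apache.spark.TaskContext

  def readLines(context: TaskContext, path: String): Iterator[String] = {
    val source = Source.fromFile(path)
    context.addOnCompleteCallback(() => source.close())  // closed when the task finishes
    source.getLines()
  }
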
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8466c2a004..c1e5e04b31 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure(
*/
private[spark] case object TaskResultLost extends TaskEndReason
+private[spark] case object TaskKilled extends TaskEndReason
+
private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index f0a1960a1b..e5e20dbb66 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -51,6 +51,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
+
// first() has to be overriden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
@@ -84,6 +97,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
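
A short sketch contrasting the two calls on the underlying Scala RDD, to which the Java wrappers delegate (input path and partition counts are illustrative).

  val logs = sc.textFile("hdfs://namenode/logs", 8)   // illustrative input
  val spread = logs.repartition(64)                   // always shuffles; can increase parallelism
  val merged = logs.coalesce(2)                       // shrinks the partition count without a shuffle
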
+
+ /**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 899e17d4fa..eeea0eddb1 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -66,6 +66,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -96,6 +109,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
@@ -599,4 +623,15 @@ object JavaPairRDD {
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+
+
+ /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
+ def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
+ implicit val cmk: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ new JavaPairRDD[K, V](rdd.rdd)
+ }
+
}
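
A hedged sketch of the new fromJavaRDD helper as called from Scala (the data is illustrative): a JavaRDD whose elements are Tuple2 pairs is re-wrapped so the pair-specific operations become available.

  import java.util.Arrays
  import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}

  val jsc = new JavaSparkContext(sc)
  val pairs: JavaRDD[(String, Int)] = jsc.parallelize(Arrays.asList(("a", 1), ("b", 2)))
  val pairRDD: JavaPairRDD[String, Int] = JavaPairRDD.fromJavaRDD(pairs)
  println(pairRDD.count())
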
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 9968bc8e5f..c47657f512 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -43,9 +43,17 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -76,6 +84,17 @@ JavaRDDLike[T, JavaRDD[T]] {
rdd.coalesce(numPartitions, shuffle)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
index 4830067f7a..3e85052cd0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
-public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
+public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
-
- public abstract Iterable<Double> call(T t);
-
- @Override
- public final Iterable<Double> apply(T t) { return call(t); }
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
index ed92d31af5..5e9b8c48b8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
@@ -27,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
-
- public abstract Double call(T t) throws Exception;
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
index b7c0d78e33..ed8fea97fc 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -23,8 +23,5 @@ import scala.reflect.ClassTag
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- @throws(classOf[Exception])
- def call(x: T) : java.lang.Iterable[R]
-
def elementType() : ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
index 7a505df4be..aae1349c5e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -23,8 +23,5 @@ import scala.reflect.ClassTag
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- @throws(classOf[Exception])
- def call(a: A, b:B) : java.lang.Iterable[C]
-
def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
index e97116986f..49e661a376 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -32,7 +32,7 @@ public abstract class Function<T, R> extends WrappedFunction1<T, R> implements S
public abstract R call(T t) throws Exception;
public ClassTag<R> returnType() {
- return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
+ return ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
new file mode 100644
index 0000000000..fb1deceab5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+import scala.runtime.AbstractFunction2;
+
+import java.io.Serializable;
+
+/**
+ * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
+ */
+public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
+ implements Serializable {
+
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
index fbd0cdabe0..ca485b3cc2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -33,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
- public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
-
public ClassTag<K> keyType() {
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
index f09559627d..cbe2306026 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -28,12 +28,9 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
-public abstract class PairFunction<T, K, V>
- extends WrappedFunction1<T, Tuple2<K, V>>
+public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
- public abstract Tuple2<K, V> call(T t) throws Exception;
-
public ClassTag<K> keyType() {
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
new file mode 100644
index 0000000000..d314dbdf1d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+import scala.runtime.AbstractFunction3
+
+/**
+ * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
+ * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
+ * isn't marked to allow that).
+ */
+private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
+ extends AbstractFunction3[T1, T2, T3, R] {
+ @throws(classOf[Exception])
+ def call(t1: T1, t2: T2, t3: T3): R
+
+ final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
+}
+
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 4d887cf195..53b53df9ac 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
-class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
deleted file mode 100644
index 93e7815ab5..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1058 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
- extends Broadcast[T](id)
- with Logging
- with Serializable {
-
- def value = value_
-
- def blockId: String = "broadcast_" + id
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var hasBlocksBitVector: BitSet = null
- @transient var numCopiesSent: Array[Int] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = new AtomicInteger(0)
-
- // Used ONLY by driver to track how many unique blocks have been sent out
- @transient var sentBlocks = new AtomicInteger(0)
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
-
- // Used only in driver
- @transient var guideMR: GuideMultipleRequests = null
-
- // Used only in Workers
- @transient var ttGuide: TalkToGuide = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks.set(variableInfo.totalBlocks)
-
- // Guide has all the blocks
- hasBlocksBitVector = new BitSet(totalBlocks)
- hasBlocksBitVector.set(0, totalBlocks)
-
- // Guide still hasn't sent any block
- numCopiesSent = new Array[Int](totalBlocks)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val driverSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- hasBlocksBitVector.synchronized {
- driverSource.hasBlocksBitVector = hasBlocksBitVector
- }
-
- // In the beginning, this is the only known source to Guide
- listOfSources += driverSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- // Start local ServeMultipleRequests thread first
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- // Initialize variables in the worker node. Driver sends everything as 0/null
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- hasBlocksBitVector = null
- numCopiesSent = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = new AtomicInteger(0)
-
- listenPortLock = new Object
- totalBlocksLock = new Object
-
- serveMR = null
- ttGuide = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- listOfSources = ListBuffer[SourceInfo]()
-
- stopBroadcast = false
- }
-
- private def getLocalSourceInfo: SourceInfo = {
- // Wait till hostName and listenPort are OK
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Wait till totalBlocks and totalBytes are OK
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- var localSourceInfo = SourceInfo(
- hostAddress, listenPort, totalBlocks, totalBytes)
-
- localSourceInfo.hasBlocks = hasBlocks.get
-
- hasBlocksBitVector.synchronized {
- localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
- }
-
- return localSourceInfo
- }
-
- // Add new SourceInfo to the listOfSources. Update if it exists already.
- // Optimizing just by OR-ing the BitVectors was BAD for performance
- private def addToListOfSources(newSourceInfo: SourceInfo) {
- listOfSources.synchronized {
- if (listOfSources.contains(newSourceInfo)) {
- listOfSources = listOfSources - newSourceInfo
- }
- listOfSources += newSourceInfo
- }
- }
-
- private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) {
- newSourceInfos.foreach { newSourceInfo =>
- addToListOfSources(newSourceInfo)
- }
- }
-
- class TalkToGuide(gInfo: SourceInfo)
- extends Thread with Logging {
- override def run() {
-
- // Keep exchaning information until all blocks have been received
- while (hasBlocks.get < totalBlocks) {
- talkOnce
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
- }
-
- // Talk one more time to let the Guide know of reception completion
- talkOnce
- }
-
- // Connect to Guide and send this worker's information
- private def talkOnce {
- var clientSocketToGuide: Socket = null
- var oosGuide: ObjectOutputStream = null
- var oisGuide: ObjectInputStream = null
-
- clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
- oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream)
- oosGuide.flush()
- oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream)
-
- // Send local information
- oosGuide.writeObject(getLocalSourceInfo)
- oosGuide.flush()
-
- // Receive source information from Guide
- var suitableSources =
- oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logDebug("Received suitableSources from Driver " + suitableSources)
-
- addToListOfSources(suitableSources)
-
- oisGuide.close()
- oosGuide.close()
- clientSocketToGuide.close()
- }
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Setup initial states of variables
- totalBlocks = gInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- hasBlocksBitVector = new BitSet(totalBlocks)
- numCopiesSent = new Array[Int](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = gInfo.totalBytes
-
- // Start ttGuide to periodically talk to the Guide
- var ttGuide = new TalkToGuide(gInfo)
- ttGuide.setDaemon(true)
- ttGuide.start()
- logInfo("TalkToGuide started...")
-
- // Start pController to run TalkToPeer threads
- var pcController = new PeerChatterController
- pcController.setDaemon(true)
- pcController.start()
- logInfo("PeerChatterController started...")
-
- // FIXME: Must fix this. This might never break if broadcast fails.
- // We should be able to break and send false. Also need to kill threads
- while (hasBlocks.get < totalBlocks) {
- Thread.sleep(MultiTracker.MaxKnockInterval)
- }
-
- return true
- }
-
- class PeerChatterController
- extends Thread with Logging {
- private var peersNowTalking = ListBuffer[SourceInfo]()
- // TODO: There is a possible bug with blocksInRequestBitVector when a
- // certain bit is NOT unset upon failure resulting in an infinite loop.
- private var blocksInRequestBitVector = new BitSet(totalBlocks)
-
- override def run() {
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- while (hasBlocks.get < totalBlocks) {
- var numThreadsToCreate = 0
- listOfSources.synchronized {
- numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
- threadPool.getActiveCount
- }
-
- while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
- var peerToTalkTo = pickPeerToTalkToRandom
-
- if (peerToTalkTo != null)
- logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
- else
- logDebug("No peer chosen...")
-
- if (peerToTalkTo != null) {
- threadPool.execute(new TalkToPeer(peerToTalkTo))
-
- // Add to peersNowTalking. Remove in the thread. We have to do this
- // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
- peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before starting some more threads
- Thread.sleep(MultiTracker.MinKnockInterval)
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- // Right now picking the one that has the most blocks this peer wants
- // Also picking peer randomly if no one has anything interesting
- private def pickPeerToTalkToRandom: SourceInfo = {
- var curPeer: SourceInfo = null
- var curMax = 0
-
- logDebug("Picking peers to talk to...")
-
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Select the peer that has the most blocks that this receiver does not
- peersNotInUse.foreach { eachSource =>
- var tempHasBlocksBitVector: BitSet = null
- hasBlocksBitVector.synchronized {
- tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
- tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size)
- tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector)
-
- if (tempHasBlocksBitVector.cardinality > curMax) {
- curPeer = eachSource
- curMax = tempHasBlocksBitVector.cardinality
- }
- }
-
- // Always picking randomly
- if (curPeer == null && peersNotInUse.size > 0) {
- // Pick uniformly the i'th required peer
- var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
-
- var peerIter = peersNotInUse.iterator
- curPeer = peerIter.next
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
- }
-
- return curPeer
- }
-
- // Picking peer with the weight of rare blocks it has
- private def pickPeerToTalkToRarestFirst: SourceInfo = {
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Count the number of copies of each block in the neighborhood
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // A block is considered rare if there are at most 2 copies of that block
- // This CONSTANT could be a function of the neighborhood size
- var rareBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
- rareBlocksIndices += i
- }
- }
-
- // Find peers with rare blocks
- var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]()
- var totalRareBlocks = 0
-
- peersNotInUse.foreach { eachPeer =>
- var hasRareBlocks = 0
- rareBlocksIndices.foreach { rareBlock =>
- if (eachPeer.hasBlocksBitVector.get(rareBlock)) {
- hasRareBlocks += 1
- }
- }
-
- if (hasRareBlocks > 0) {
- peersWithRareBlocks += ((eachPeer, hasRareBlocks))
- }
- totalRareBlocks += hasRareBlocks
- }
-
- // Select a peer from peersWithRareBlocks based on weight calculated from
- // unique rare blocks
- var selectedPeerToTalkTo: SourceInfo = null
-
- if (peersWithRareBlocks.size > 0) {
- // Sort the peers based on how many rare blocks they have
- peersWithRareBlocks.sortBy(_._2)
-
- var randomNumber = MultiTracker.ranGen.nextDouble
- var tempSum = 0.0
-
- var i = 0
- do {
- tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks)
- if (tempSum >= randomNumber) {
- selectedPeerToTalkTo = peersWithRareBlocks(i)._1
- }
- i += 1
- } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null)
- }
-
- if (selectedPeerToTalkTo == null) {
- selectedPeerToTalkTo = pickPeerToTalkToRandom
- }
-
- return selectedPeerToTalkTo
- }
-
- class TalkToPeer(peerToTalkTo: SourceInfo)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- override def run() {
- // TODO: There is a possible bug here regarding blocksInRequestBitVector
- var blockToAskFor = -1
-
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run() {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
-
- logInfo("TalkToPeer started... => " + peerToTalkTo)
-
- try {
- // Connect to the source
- peerSocketToSource =
- new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource =
- new ObjectInputStream(peerSocketToSource.getInputStream)
-
- // Receive latest SourceInfo from peerToTalkTo
- var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
- // Update listOfSources
- addToListOfSources(newPeerToTalkTo)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
-
- var keepReceiving = true
-
- while (hasBlocks.get < totalBlocks && keepReceiving) {
- blockToAskFor =
- pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector)
-
- // No block to request
- if (blockToAskFor < 0) {
- // Nothing to receive from newPeerToTalkTo
- keepReceiving = false
- } else {
- // Let other threads know that blockToAskFor is being requested
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor)
- }
-
- // Start with sending the blockID
- oosSource.writeObject(blockToAskFor)
- oosSource.flush()
-
- // CHANGED: Driver might send some other block than the one
- // requested to ensure fast spreading of all blocks.
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
-
- if (!hasBlocksBitVector.get(bcBlock.blockID)) {
- arrayOfBlocks(bcBlock.blockID) = bcBlock
-
- // Update the hasBlocksBitVector first
- hasBlocksBitVector.synchronized {
- hasBlocksBitVector.set(bcBlock.blockID)
- hasBlocks.getAndIncrement
- }
-
- // Some block(may NOT be blockToAskFor) has arrived.
- // In any case, blockToAskFor is not in request any more
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
-
- // Reset blockToAskFor to -1. Else it will be considered missing
- blockToAskFor = -1
- }
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logError("TalktoPeer had a " + e)
- // FIXME: Remove 'newPeerToTalkTo' from listOfSources
- // We probably should have the following in some form, but not
- // really here. This exception can happen if the sender just breaks connection
- // listOfSources.synchronized {
- // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
- // listOfSources = listOfSources - peerToTalkTo
- // }
- }
- } finally {
- // blockToAskFor != -1 => there was an exception
- if (blockToAskFor != -1) {
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
- }
-
- cleanUpConnections()
- }
- }
-
- // Right now it picks a block uniformly that this peer does not have
- private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Pick uniformly the i'th required block
- var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
- var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
-
- while (i > 0) {
- pickedBlockIndex =
- needBlocksBitVector.nextSetBit(pickedBlockIndex + 1)
- i -= 1
- }
-
- return pickedBlockIndex
- }
- }
-
- // Pick the block that seems to be the rarest across sources
- private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Count the number of copies for each block across all sources
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // Find the minimum
- var minVal = Integer.MAX_VALUE
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) {
- minVal = numCopiesPerBlock(i)
- }
- }
-
- // Find the blocks with the least copies that this peer does not have
- var minBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) {
- minBlocksIndices += i
- }
- }
-
- // Now select a random index from minBlocksIndices
- if (minBlocksIndices.size == 0) {
- return -1
- } else {
- // Pick uniformly the i'th index
- var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
- return minBlocksIndices(i)
- }
- }
- }
-
- private def cleanUpConnections() {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
-
- // Delete from peersNowTalking
- peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
- }
- }
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close the socket here; else, thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- listOfSources.foreach { sourceInfo =>
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Throw away whatever comes in
- gisSource.readObject.asInstanceOf[SourceInfo]
-
- // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
- gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sourceInfo: SourceInfo = null
- private var selectedSources: ListBuffer[SourceInfo] = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its information
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Select a suitable source and send it back to the worker
- selectedSources = selectSuitableSources(sourceInfo)
- logDebug("Sending selectedSources:" + selectedSources)
- oos.writeObject(selectedSources)
- oos.flush()
-
- // Add this source to the listOfSources
- addToListOfSources(sourceInfo)
- } catch {
- case e: Exception => {
- // Assuming exception caused by receiver failure: remove
- if (listOfSources != null) {
- listOfSources.synchronized { listOfSources -= sourceInfo }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Randomly select some sources to send back
- private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
- var selectedSources = ListBuffer[SourceInfo]()
-
- // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
- // then add skipSourceInfo to setOfCompletedSources. Return blank.
- if (skipSourceInfo.hasBlocks == totalBlocks) {
- setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
- return selectedSources
- }
-
- listOfSources.synchronized {
- if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
- selectedSources = listOfSources.clone
- } else {
- var picksLeft = MultiTracker.MaxPeersInGuideResponse
- var alreadyPicked = new BitSet(listOfSources.size)
-
- while (picksLeft > 0) {
- var i = -1
-
- do {
- i = MultiTracker.ranGen.nextInt(listOfSources.size)
- } while (alreadyPicked.get(i))
-
- var peerIter = listOfSources.iterator
- var curPeer = peerIter.next
-
- // Set the BitSet before i is decremented
- alreadyPicked.set(i)
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
-
- selectedSources += curPeer
-
- picksLeft = picksLeft - 1
- }
- }
- }
-
- // Remove the receiving source (if present)
- selectedSources = selectedSources - skipSourceInfo
-
- return selectedSources
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
- // Server at most MultiTracker.MaxChatSlots peers
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ServeSingleRequest is running")
-
- override def run() {
- try {
- // Send latest local SourceInfo to the receiver
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- oos.writeObject(getLocalSourceInfo)
- oos.flush()
-
- // Receive latest SourceInfo from the receiver
- var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- addToListOfSources(rxSourceInfo)
- }
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = MultiTracker.MaxChatBlocks
-
- while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
- // Receive which block to send
- var blockToSend = ois.readObject.asInstanceOf[Int]
-
-          // If this is the driver AND at least one copy of each block has not
-          // been sent out already, MODIFY blockToSend
- if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
- blockToSend = sentBlocks.getAndIncrement
- }
-
- // Send the block
- sendBlock(blockToSend)
- rxSourceInfo.hasBlocksBitVector.set(blockToSend)
-
- numBlocksToSend -= 1
-
- // Receive latest SourceInfo from the receiver
- rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
- addToListOfSources(rxSourceInfo)
-
- curTime = System.currentTimeMillis
-          // Stop sending only if someone is waiting in the queue
- if (curTime - startTime >= MultiTracker.MaxChatTime &&
- threadPool.getQueue.size > 0) {
- keepSending = false
- }
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendBlock(blockToSend: Int) {
- try {
- oos.writeObject(arrayOfBlocks(blockToSend))
- oos.flush()
- } catch {
- case e: Exception => logError("sendBlock had a " + e)
- }
- logDebug("Sent block: " + blockToSend + " to " + clientSocket)
- }
- }
- }
-}
-
-private[spark] class BitTorrentBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new BitTorrentBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
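
For reference, the peer-selection routine deleted above (selectSuitableSources) reduces to sampling up to MaxPeersInGuideResponse distinct sources at random from the known list. Below is a minimal, self-contained Scala sketch of that same sampling idea; the names (RandomPeerPick, pick) are illustrative only and not part of the Spark API.

import java.util.BitSet

import scala.collection.mutable.ListBuffer
import scala.util.Random

object RandomPeerPick {
  // Pick up to k distinct random elements, mirroring the BitSet-guarded
  // random-index approach used by selectSuitableSources.
  def pick[T](sources: List[T], k: Int, rand: Random = new Random): List[T] = {
    if (sources.size <= k) {
      sources
    } else {
      val alreadyPicked = new BitSet(sources.size)
      val selected = ListBuffer[T]()
      while (selected.size < k) {
        var i = rand.nextInt(sources.size)
        while (alreadyPicked.get(i)) {
          i = rand.nextInt(sources.size)
        }
        alreadyPicked.set(i)
        selected += sources(i)
      }
      selected.toList
    }
  }

  def main(args: Array[String]) {
    println(pick(List("w1", "w2", "w3", "w4", "w5"), 3))
  }
}

Unlike the original, this sketch returns an immutable List and leaves removing the requesting source to the caller.
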
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 9db26ae6de..609464e38d 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.{HttpServer, Logging, SparkEnv}
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet}
-
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
- def blockId: String = "broadcast_" + id
+ def blockId = BroadcastBlockId(id)
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -82,7 +81,7 @@ private object HttpBroadcast extends Logging {
private var server: HttpServer = null
private val files = new TimeStampedHashSet[String]
- private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
+ private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
private lazy val compressionCodec = CompressionCodec.createCodec()
@@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
}
def write(id: Long, value: Any) {
- val file = new File(broadcastDir, "broadcast-" + id)
+ val file = new File(broadcastDir, BroadcastBlockId(id).name)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
- val url = serverUri + "/broadcast-" + id
+ val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
if (compress) {
compressionCodec.compressedInputStream(new URL(url).openStream())
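
The HttpBroadcast change above swaps ad-hoc string ids ("broadcast_" + id, "broadcast-" + id) for the typed BroadcastBlockId. The sketch below shows, with an illustrative case class rather than the real org.apache.spark.storage.BroadcastBlockId, what such a typed id buys; the exact naming scheme used in `name` is an assumption for the demo.

// Illustrative stand-in for a typed block id; not the real BroadcastBlockId.
case class DemoBroadcastBlockId(broadcastId: Long) {
  // Assumed naming scheme; the real class defines its own name field.
  def name = "broadcast_" + broadcastId
}

object BlockIdDemo {
  def main(args: Array[String]) {
    val a = DemoBroadcastBlockId(42L)
    val b = DemoBroadcastBlockId(42L)
    println(a.name)  // broadcast_42
    println(a == b)  // true: case-class equality, no fragile string concatenation at call sites
  }
}
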
diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
deleted file mode 100644
index 21ec94659e..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
+++ /dev/null
@@ -1,410 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.Random
-
-import scala.collection.mutable.Map
-
-import org.apache.spark._
-import org.apache.spark.util.Utils
-
-private object MultiTracker
-extends Logging {
-
- // Tracker Messages
- val REGISTER_BROADCAST_TRACKER = 0
- val UNREGISTER_BROADCAST_TRACKER = 1
- val FIND_BROADCAST_TRACKER = 2
-
- // Map to keep track of guides of ongoing broadcasts
- var valueToGuideMap = Map[Long, SourceInfo]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var _isDriver = false
-
- private var stopBroadcast = false
-
- private var trackMV: TrackMultipleValues = null
-
- def initialize(__isDriver: Boolean) {
- synchronized {
- if (!initialized) {
- _isDriver = __isDriver
-
- if (isDriver) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
-
- // Set DriverHostAddress to the driver's IP address for the slaves to read
- System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
- }
-
- initialized = true
- }
- }
- }
-
- def stop() {
- stopBroadcast = true
- }
-
- // Load common parameters
- private var DriverHostAddress_ = System.getProperty(
- "spark.MultiTracker.DriverHostAddress", "")
- private var DriverTrackerPort_ = System.getProperty(
- "spark.broadcast.driverTrackerPort", "11111").toInt
- private var BlockSize_ = System.getProperty(
- "spark.broadcast.blockSize", "4096").toInt * 1024
- private var MaxRetryCount_ = System.getProperty(
- "spark.broadcast.maxRetryCount", "2").toInt
-
- private var TrackerSocketTimeout_ = System.getProperty(
- "spark.broadcast.trackerSocketTimeout", "50000").toInt
- private var ServerSocketTimeout_ = System.getProperty(
- "spark.broadcast.serverSocketTimeout", "10000").toInt
-
- private var MinKnockInterval_ = System.getProperty(
- "spark.broadcast.minKnockInterval", "500").toInt
- private var MaxKnockInterval_ = System.getProperty(
- "spark.broadcast.maxKnockInterval", "999").toInt
-
- // Load TreeBroadcast config params
- private var MaxDegree_ = System.getProperty(
- "spark.broadcast.maxDegree", "2").toInt
-
- // Load BitTorrentBroadcast config params
- private var MaxPeersInGuideResponse_ = System.getProperty(
- "spark.broadcast.maxPeersInGuideResponse", "4").toInt
-
- private var MaxChatSlots_ = System.getProperty(
- "spark.broadcast.maxChatSlots", "4").toInt
- private var MaxChatTime_ = System.getProperty(
- "spark.broadcast.maxChatTime", "500").toInt
- private var MaxChatBlocks_ = System.getProperty(
- "spark.broadcast.maxChatBlocks", "1024").toInt
-
- private var EndGameFraction_ = System.getProperty(
- "spark.broadcast.endGameFraction", "0.95").toDouble
-
- def isDriver = _isDriver
-
- // Common config params
- def DriverHostAddress = DriverHostAddress_
- def DriverTrackerPort = DriverTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- // TreeBroadcast configs
- def MaxDegree = MaxDegree_
-
- // BitTorrentBroadcast configs
- def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
-
- def MaxChatSlots = MaxChatSlots_
- def MaxChatTime = MaxChatTime_
- def MaxChatBlocks = MaxChatBlocks_
-
- def EndGameFraction = EndGameFraction_
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(DriverTrackerPort)
- logInfo("TrackMultipleValues started at " + serverSocket)
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(TrackerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- if (stopBroadcast) {
- logInfo("Stopping TrackMultipleValues...")
- }
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // First, read message type
- val messageType = ois.readObject.asInstanceOf[Int]
-
- if (messageType == REGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
- // Receive hostAddress and listenPort
- val gInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Add to the map
- valueToGuideMap.synchronized {
- valueToGuideMap += (id -> gInfo)
- }
-
- logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- // Remove from the map
- valueToGuideMap.synchronized {
- valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
- }
-
- logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == FIND_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- var gInfo =
- if (valueToGuideMap.contains(id)) valueToGuideMap(id)
- else SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
-
- // Send reply back
- oos.writeObject(gInfo)
- oos.flush()
- } else {
- throw new SparkException("Undefined messageType at TrackMultipleValues")
- }
- } catch {
- case e: Exception => {
- logError("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
-
- def getGuideInfo(variableLong: Long): SourceInfo = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out GuideInfo
- clientSocketToTracker =
- new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send messageType/intention
- oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
- oosTracker.flush()
-
- // Send Long and receive GuideInfo
- oosTracker.writeObject(variableLong)
- oosTracker.flush()
- gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
- } catch {
- case e: Exception => logError("getGuideInfo had a " + e)
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
-
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
-
- logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
- return gInfo
- }
-
- def registerBroadcast(id: Long, gInfo: SourceInfo) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(REGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Send this tracker's information
- oosST.writeObject(gInfo)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- def unregisterBroadcast(id: Long) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- // Helper method to convert an object to Array[BroadcastBlock]
- def blockifyObject[IN](obj: IN): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- oos.writeObject(obj)
- oos.close()
- baos.close()
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream(byteArray)
-
- var blockNum = (byteArray.length / BlockSize)
- if (byteArray.length % BlockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock](blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, BlockSize)) {
- val thisBlockSize = math.min(BlockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
-
- retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
- blockID += 1
- }
- bais.close()
-
- var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- // Helper method to convert Array[BroadcastBlock] to object
- def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
- totalBytes: Int,
- totalBlocks: Int): OUT = {
-
- var retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject(retByteArray)
- }
-
- private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
- }
- val retVal = in.readObject.asInstanceOf[OUT]
- in.close()
- return retVal
- }
-}
-
-private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
-extends Serializable
-
-private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
- totalBlocks: Int,
- totalBytes: Int)
-extends Serializable {
- @transient var hasBlocks = 0
-}
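
getGuideInfo in the deleted MultiTracker above retries the tracker lookup a bounded number of times, sleeping a random interval between MinKnockInterval and MaxKnockInterval before each new attempt. The retry-with-jitter pattern in isolation looks roughly like the hedged sketch below; the names and timings are made up for the demo.

import scala.util.Random

object RetryWithJitter {
  // Keep attempting until a usable answer arrives or the retry budget runs out,
  // sleeping a random interval between tries (cf. MinKnockInterval/MaxKnockInterval).
  def retry[T](maxRetries: Int, minSleepMs: Int, maxSleepMs: Int)(attempt: => Option[T]): Option[T] = {
    val rand = new Random
    var retriesLeft = maxRetries
    var result: Option[T] = None
    do {
      result = attempt
      if (result.isEmpty) {
        Thread.sleep(rand.nextInt(maxSleepMs - minSleepMs) + minSleepMs)
      }
      retriesLeft -= 1
    } while (retriesLeft > 0 && result.isEmpty)
    result
  }

  def main(args: Array[String]) {
    var calls = 0
    val gInfo = retry(maxRetries = 5, minSleepMs = 100, maxSleepMs = 500) {
      calls += 1
      if (calls >= 3) Some("host:port") else None
    }
    println("Got " + gInfo + " after " + calls + " attempts")
  }
}
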
diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
deleted file mode 100644
index baa1fd6da4..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.util.BitSet
-
-import org.apache.spark._
-
-/**
- * Used to keep and pass around information of peers involved in a broadcast
- */
-private[spark] case class SourceInfo (hostAddress: String,
- listenPort: Int,
- totalBlocks: Int = SourceInfo.UnusedParam,
- totalBytes: Int = SourceInfo.UnusedParam)
-extends Comparable[SourceInfo] with Logging {
-
- var currentLeechers = 0
- var receptionFailed = false
-
- var hasBlocks = 0
- var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
-
- // Ascending sort based on leecher count
- def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
-}
-
-/**
- * Helper Object of SourceInfo for its constants
- */
-private[spark] object SourceInfo {
- // Broadcast has not started yet! Should never happen.
- val TxNotStartedRetry = -1
- // Broadcast has already finished. Try default mechanism.
- val TxOverGoToDefault = -3
- // Other constants
- val StopBroadcast = -2
- val UnusedParam = 0
-}
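
SourceInfo's compareTo above orders peers ascending by leecher count, so the least-loaded source sorts first. The same ordering in a couple of lines, with an illustrative Peer type in place of SourceInfo:

case class Peer(host: String, currentLeechers: Int)

object PeerOrderingDemo {
  def main(args: Array[String]) {
    val peers = List(Peer("w1", 3), Peer("w2", 0), Peer("w3", 1))
    // Ascending by leecher count: least-loaded peer first.
    println(peers.sortBy(_.currentLeechers))
  }
}
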
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000000..073a0a5029
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def broadcastId = BroadcastBlockId(id)
+
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[TorrentBlock] = null
+ @transient var totalBlocks = -1
+ @transient var totalBytes = -1
+ @transient var hasBlocks = 0
+
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ hasBlocks = tInfo.totalBlocks
+
+ // Store meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ }
+
+ // Store individual pieces
+ for (i <- 0 until totalBlocks) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ }
+ }
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ val start = System.nanoTime
+ logInfo("Started reading broadcast variable " + id)
+
+ // Initialize @transient variables that will receive garbage values from the master.
+ resetWorkerVariables()
+
+ if (receiveBroadcast(id)) {
+ value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+
+ // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ // This creates a tradeoff between memory usage and latency.
+ // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+ // Remove arrayOfBlocks from memory once value_ is on local cache
+ resetWorkerVariables()
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ private def resetWorkerVariables() {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ // Receive meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ var attemptId = 10
+ while (attemptId > 0 && totalBlocks == -1) {
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(metaId) match {
+ case Some(x) =>
+ val tInfo = x.asInstanceOf[TorrentInfo]
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+ hasBlocks = 0
+
+ case None =>
+ Thread.sleep(500)
+ }
+ }
+ attemptId -= 1
+ }
+ if (totalBlocks == -1) {
+ return false
+ }
+
+ // Receive actual blocks
+ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+ for (pid <- recvOrder) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(pieceId) match {
+ case Some(x) =>
+ arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+ hasBlocks += 1
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ }
+
+ (hasBlocks == totalBlocks)
+ }
+
+}
+
+private object TorrentBroadcast
+extends Logging {
+
+ private var initialized = false
+
+ def initialize(_isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ initialized = false
+ }
+
+ val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
+
+ def blockifyObject[T](obj: T): TorrentInfo = {
+ val byteArray = Utils.serialize[T](obj)
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BLOCK_SIZE)
+ if (byteArray.length % BLOCK_SIZE != 0)
+ blockNum += 1
+
+ var retVal = new Array[TorrentBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
+ val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ tInfo.hasBlocks = blockNum
+
+ return tInfo
+ }
+
+ def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
+ }
+ Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+ }
+
+}
+
+private[spark] case class TorrentBlock(
+ blockID: Int,
+ byteArray: Array[Byte])
+ extends Serializable
+
+private[spark] case class TorrentInfo(
+ @transient arrayOfBlocks : Array[TorrentBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+ extends Serializable {
+
+ @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+ extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+}
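
TorrentBroadcast.blockifyObject above serializes the value and slices the bytes into BLOCK_SIZE pieces, and unBlockifyObject concatenates and deserializes them. The following self-contained round-trip sketch shows the same scheme using plain Java serialization; the helper names and the 4 KB block size are assumptions for the demo, not Spark APIs (the real code derives BLOCK_SIZE from spark.broadcast.blockSize).

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object BlockifyDemo {
  val BlockSize = 4 * 1024  // demo value only

  case class Block(blockID: Int, byteArray: Array[Byte])

  // Serialize the object and cut the resulting bytes into fixed-size blocks.
  def blockify(obj: AnyRef): (Array[Block], Int) = {
    val baos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(baos)
    oos.writeObject(obj)
    oos.close()
    val bytes = baos.toByteArray
    val blocks = bytes.grouped(BlockSize).zipWithIndex.map {
      case (chunk, i) => Block(i, chunk)
    }.toArray
    (blocks, bytes.length)
  }

  // Reassemble the byte stream in block order and deserialize it back.
  def unblockify(blocks: Array[Block], totalBytes: Int): AnyRef = {
    val out = new Array[Byte](totalBytes)
    var offset = 0
    for (b <- blocks) {
      System.arraycopy(b.byteArray, 0, out, offset, b.byteArray.length)
      offset += b.byteArray.length
    }
    new ObjectInputStream(new ByteArrayInputStream(out)).readObject()
  }

  def main(args: Array[String]) {
    val value = (1 to 100000).toVector
    val (blocks, totalBytes) = blockify(value)
    println(blocks.length + " blocks, " + totalBytes + " bytes")
    println(unblockify(blocks, totalBytes) == value)  // true: the value survives the round trip
  }
}
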
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
deleted file mode 100644
index 80c97ca073..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
+++ /dev/null
@@ -1,603 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{Comparator, Random, UUID}
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
-
- def value = value_
-
- def blockId = "broadcast_" + id
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- listOfSources += masterSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
-          // Initialize everything because the Driver will only send null/0 values.
-          // Only the 1st worker on a node can be here; the others will read from the cache.
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- var clientSocketToDriver: Socket = null
- var oosDriver: ObjectOutputStream = null
- var oisDriver: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- // Connect to Driver and send this worker's Information
- clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
- oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
- oosDriver.flush()
- oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
-
- logDebug("Connected to Driver's guiding object")
-
- // Send local source information
- oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
- oosDriver.flush()
-
- // Receive source information from Driver
- var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = sourceInfo.totalBytes
-
- logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission(sourceInfo)
- val time = (System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Driver will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Driver
- oosDriver.writeObject(sourceInfo)
-
- if (oisDriver != null) {
- oisDriver.close()
- }
- if (oosDriver != null) {
- oosDriver.close()
- }
- if (clientSocketToDriver != null) {
- clientSocketToDriver.close()
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return (hasBlocks == totalBlocks)
- }
-
- /**
- * Tries to receive broadcast from the source and returns Boolean status.
- * This might be called multiple times to retry a defined number of times.
- */
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
-
- logDebug("Inside receiveSingleTransmission")
- logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush()
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
-
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
- }
- } catch {
- case e: Exception => logError("receiveSingleTransmission had a " + e)
- } finally {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close()
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
-            // Stop the broadcast if at least one worker has connected and
-            // everyone connected so far is done. Compare with
-            // listOfSources.size - 1, because the list includes the Guide itself.
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close() the socket here; else, the thread will close() it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- var listIter = listOfSources.iterator
- while (listIter.hasNext) {
- var sourceInfo = listIter.next
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal
- gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
-      private var thisWorkerInfo: SourceInfo = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
-          // The connecting worker sends the hostAddress and listenPort it will
-          // be listening on. Other fields are invalid (SourceInfo.UnusedParam)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource(sourceInfo)
- logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject(selectedSourceInfo)
- oos.flush()
-
-            // Add this new source (if it completes reception) to the list of sources
- thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes)
- logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
- listOfSources += thisWorkerInfo
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in listOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert(listOfSources.contains(selectedSourceInfo))
-
- // Remove first
- // (Currently removing a source based on just one failure notification!)
- listOfSources = listOfSources - selectedSourceInfo
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources.synchronized {
- setOfCompletedSources += thisWorkerInfo
- }
-
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
- }
- } catch {
- case e: Exception => {
- // Remove failed worker from listOfSources and update leecherCount of
- // corresponding source worker
- listOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- listOfSources = listOfSources - selectedSourceInfo
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
-
- // Remove thisWorkerInfo
- if (listOfSources != null) {
- listOfSources = listOfSources - thisWorkerInfo
- }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
-      // Assumes the caller holds a synchronized block on listOfSources.
-      // Select the source with the most leechers; this fills the tree level by level.
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- var maxLeechers = -1
- var selectedSource: SourceInfo = null
-
- listOfSources.foreach { source =>
- if ((source.hostAddress != skipSourceInfo.hostAddress ||
- source.listenPort != skipSourceInfo.listenPort) &&
- source.currentLeechers < MultiTracker.MaxDegree &&
- source.currentLeechers > maxLeechers) {
- selectedSource = source
- maxLeechers = source.currentLeechers
- }
- }
-
- // Update leecher count
- selectedSource.currentLeechers += 1
- return selectedSource
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
-
- var threadPool = Utils.newDaemonCachedThreadPool()
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => { }
- }
-
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run() {
- try {
- logInfo("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- // If not a valid range, stop broadcast
- if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- sendObject
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendObject() {
-        // Wait until the SourceInfo has been received from the Driver
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized { hasBlocksLock.wait() }
- }
- try {
- oos.writeObject(arrayOfBlocks(i))
- oos.flush()
- } catch {
- case e: Exception => logError("sendObject had a " + e)
- }
- logDebug("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
-}
-
-private[spark] class TreeBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new TreeBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
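
selectSuitableSource in the deleted TreeBroadcast above hands each new worker the source that currently has the most leechers, provided that source is still below MaxDegree, so the broadcast tree fills one level at a time. A small sketch of that heuristic follows, with illustrative types and an Option result instead of assuming a candidate always exists; the fan-out of 2 mirrors the spark.broadcast.maxDegree default shown above.

import scala.collection.mutable.ListBuffer

case class DemoSource(host: String, var currentLeechers: Int)

object TreeFillDemo {
  val MaxDegree = 2  // assumed fan-out for the demo

  // Prefer the busiest source that still has a free slot; bump its leecher count.
  def select(sources: Seq[DemoSource], requesterHost: String): Option[DemoSource] = {
    val candidates = sources.filter(s => s.host != requesterHost && s.currentLeechers < MaxDegree)
    if (candidates.isEmpty) {
      None
    } else {
      val chosen = candidates.maxBy(_.currentLeechers)
      chosen.currentLeechers += 1
      Some(chosen)
    }
  }

  def main(args: Array[String]) {
    val sources = ListBuffer(DemoSource("driver", 0))
    for (i <- 1 to 5) {
      val host = "worker-" + i
      println(host + " pulls from " + select(sources, host).map(_.host))
      sources += DemoSource(host, 0)
    }
  }
}
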
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 1cfff5e565..275331724a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -21,12 +21,14 @@ import scala.collection.immutable.List
import org.apache.spark.deploy.ExecutorState.ExecutorState
import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
+import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.ExecutorRunner
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
+/** Contains messages sent between Scheduler actor nodes. */
private[deploy] object DeployMessages {
// Worker to Master
@@ -52,17 +54,20 @@ private[deploy] object DeployMessages {
exitStatus: Option[Int])
extends DeployMessage
+ case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription])
+
case class Heartbeat(workerId: String) extends DeployMessage
// Master to Worker
- case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
+ case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
- case class KillExecutor(appId: String, execId: Int) extends DeployMessage
+ case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
+ masterUrl: String,
appId: String,
execId: Int,
appDesc: ApplicationDescription,
@@ -76,9 +81,11 @@ private[deploy] object DeployMessages {
case class RegisterApplication(appDescription: ApplicationDescription)
extends DeployMessage
+ case class MasterChangeAcknowledged(appId: String)
+
// Master to Client
- case class RegisteredApplication(appId: String) extends DeployMessage
+ case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
@@ -94,6 +101,10 @@ private[deploy] object DeployMessages {
case object StopClient
+ // Master to Worker & Client
+
+ case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
+
// MasterWebUI To Master
case object RequestMasterState
@@ -101,7 +112,8 @@ private[deploy] object DeployMessages {
// Master to MasterWebUI
case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
- activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+ activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
+ status: MasterState) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
@@ -123,12 +135,7 @@ private[deploy] object DeployMessages {
assert (port > 0)
}
- // Actor System to Master
-
- case object CheckForWorkerTimeOut
-
- case object RequestWebUIPort
-
- case class WebUIPortResponse(webUIBoundPort: Int)
+ // Actor System to Worker
+ case object SendHeartbeat
}
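
The new messages above (MasterChanged, WorkerSchedulerStateResponse, MasterChangeAcknowledged, and the extra masterUrl fields) extend the Master/Worker/Client wire protocol for failover. The sketch below shows how such a sealed message hierarchy is typically consumed on the receiving side; the handler and the Demo* case classes are illustrative, not Spark's actual Worker code.

sealed trait DemoDeployMessage extends Serializable
case class DemoRegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DemoDeployMessage
case class DemoMasterChanged(masterUrl: String, masterWebUiUrl: String) extends DemoDeployMessage
case object DemoSendHeartbeat extends DemoDeployMessage

object DeployProtocolDemo {
  // Exhaustive match over the sealed trait: the compiler warns if a message type is missed.
  def handle(msg: DemoDeployMessage): String = msg match {
    case DemoRegisteredWorker(masterUrl, _) => "registered with " + masterUrl
    case DemoMasterChanged(masterUrl, _)    => "switching to new master " + masterUrl
    case DemoSendHeartbeat                  => "sending heartbeat"
  }

  def main(args: Array[String]) {
    println(handle(DemoMasterChanged("spark://master2:7077", "http://master2:8080")))
  }
}
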
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala
new file mode 100644
index 0000000000..2abf0b69dd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+/**
+ * Used to send state on-the-wire about Executors from Worker to Master.
+ * This state is sufficient for the Master to reconstruct its internal data structures during
+ * failover.
+ */
+private[spark] class ExecutorDescription(
+ val appId: String,
+ val execId: Int,
+ val cores: Int,
+ val state: ExecutorState.Value)
+ extends Serializable {
+
+ override def toString: String =
+ "ExecutorState(appId=%s, execId=%d, cores=%d, state=%s)".format(appId, execId, cores, state)
+}
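
Per the comment above, these descriptions let a recovering Master rebuild its executor bookkeeping from what the Workers report. A small sketch of that reconstruction step, grouping reported executors by application; the types and aggregation are illustrative, not the Master's actual recovery code.

case class DemoExecutorDescription(appId: String, execId: Int, cores: Int)

object FailoverRebuildDemo {
  def main(args: Array[String]) {
    val reported = Seq(
      DemoExecutorDescription("app-1", 0, 4),
      DemoExecutorDescription("app-1", 1, 4),
      DemoExecutorDescription("app-2", 0, 2))

    // Rebuild a per-application view: executor count and total cores.
    val byApp = reported.groupBy(_.appId).map { case (appId, execs) =>
      appId -> (execs.size, execs.map(_.cores).sum)
    }
    println(byApp)  // e.g. Map(app-1 -> (2,8), app-2 -> (1,2))
  }
}
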
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
new file mode 100644
index 0000000000..668032a3a2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -0,0 +1,420 @@
+/*
+ *
+ * * Licensed to the Apache Software Foundation (ASF) under one or more
+ * * contributor license agreements. See the NOTICE file distributed with
+ * * this work for additional information regarding copyright ownership.
+ * * The ASF licenses this file to You under the Apache License, Version 2.0
+ * * (the "License"); you may not use this file except in compliance with
+ * * the License. You may obtain a copy of the License at
+ * *
+ * * http://www.apache.org/licenses/LICENSE-2.0
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS,
+ * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * * See the License for the specific language governing permissions and
+ * * limitations under the License.
+ *
+ */
+
+package org.apache.spark.deploy
+
+import java.io._
+import java.net.URL
+import java.util.concurrent.TimeoutException
+
+import scala.concurrent.{Await, future, promise}
+import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.collection.mutable.ListBuffer
+import scala.sys.process._
+
+import net.liftweb.json.JsonParser
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.deploy.master.RecoveryState
+
+/**
+ * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
+ * In order to mimic a real distributed cluster more closely, Docker is used.
+ * Execute using
+ * ./spark-class org.apache.spark.deploy.FaultToleranceTest
+ *
+ * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
+ * - spark.deploy.recoveryMode=ZOOKEEPER
+ * - spark.deploy.zookeeper.url=172.17.42.1:2181
+ * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
+ *
+ * Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
+ * working installation of Docker. In addition to having Docker, the following are assumed:
+ * - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
+ * - The docker images tagged spark-test-master and spark-test-worker are built from the
+ * docker/ directory. Run 'docker/spark-test/build' to generate these.
+ */
+private[spark] object FaultToleranceTest extends App with Logging {
+ val masters = ListBuffer[TestMasterInfo]()
+ val workers = ListBuffer[TestWorkerInfo]()
+ var sc: SparkContext = _
+
+ var numPassed = 0
+ var numFailed = 0
+
+ val sparkHome = System.getenv("SPARK_HOME")
+ assertTrue(sparkHome != null, "Run with a valid SPARK_HOME")
+
+ val containerSparkHome = "/opt/spark"
+ val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome)
+
+ System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip
+
+ def afterEach() {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ terminateCluster()
+ }
+
+ test("sanity-basic") {
+ addMasters(1)
+ addWorkers(1)
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("sanity-many-masters") {
+ addMasters(3)
+ addWorkers(3)
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("single-master-halt") {
+ addMasters(3)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ delay(30 seconds)
+ assertValidClusterState()
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("single-master-restart") {
+ addMasters(1)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ addMasters(1)
+ delay(30 seconds)
+ assertValidClusterState()
+
+ killLeader()
+ addMasters(1)
+ delay(30 seconds)
+ assertValidClusterState()
+ }
+
+ test("cluster-failure") {
+ addMasters(2)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ terminateCluster()
+ addMasters(2)
+ addWorkers(2)
+ assertValidClusterState()
+ }
+
+ test("all-but-standby-failure") {
+ addMasters(2)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ workers.foreach(_.kill())
+ workers.clear()
+ delay(30 seconds)
+ addWorkers(2)
+ assertValidClusterState()
+ }
+
+ test("rolling-outage") {
+ addMasters(1)
+ delay()
+ addMasters(1)
+ delay()
+ addMasters(1)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+ assertTrue(getLeader == masters.head)
+
+ (1 to 3).foreach { _ =>
+ killLeader()
+ delay(30 seconds)
+ assertValidClusterState()
+ assertTrue(getLeader == masters.head)
+ addMasters(1)
+ }
+ }
+
+ def test(name: String)(fn: => Unit) {
+ try {
+ fn
+ numPassed += 1
+ logInfo("Passed: " + name)
+ } catch {
+ case e: Exception =>
+ numFailed += 1
+ logError("FAILED: " + name, e)
+ }
+ afterEach()
+ }
+
+ def addMasters(num: Int) {
+ (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
+ }
+
+ def addWorkers(num: Int) {
+ val masterUrls = getMasterUrls(masters)
+ (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
+ }
+
+ /** Creates a SparkContext, which constructs a Client to interact with our cluster. */
+ def createClient() = {
+ if (sc != null) { sc.stop() }
+ // Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this
+ // property, we need to reset it.
+ System.setProperty("spark.driver.port", "0")
+ sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome)
+ }
+
+ def getMasterUrls(masters: Seq[TestMasterInfo]): String = {
+ "spark://" + masters.map(master => master.ip + ":7077").mkString(",")
+ }
+
+ def getLeader: TestMasterInfo = {
+ val leaders = masters.filter(_.state == RecoveryState.ALIVE)
+ assertTrue(leaders.size == 1)
+ leaders(0)
+ }
+
+ def killLeader(): Unit = {
+ masters.foreach(_.readState())
+ val leader = getLeader
+ masters -= leader
+ leader.kill()
+ }
+
+ def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
+
+ def terminateCluster() {
+ masters.foreach(_.kill())
+ workers.foreach(_.kill())
+ masters.clear()
+ workers.clear()
+ }
+
+ /** This includes Client retry logic, so it may take a while if the cluster is recovering. */
+ def assertUsable() = {
+ val f = future {
+ try {
+ val res = sc.parallelize(0 until 10).collect()
+ assertTrue(res.toList == (0 until 10))
+ true
+ } catch {
+ case e: Exception =>
+ logError("assertUsable() had exception", e)
+ e.printStackTrace()
+ false
+ }
+ }
+
+ // Avoid waiting indefinitely (e.g., we could register but get no executors).
+ assertTrue(Await.result(f, 120 seconds))
+ }
+
+ /**
+ * Asserts that the cluster is usable and that the expected masters and workers
+ * are all alive in a proper configuration (e.g., only one leader).
+ */
+ def assertValidClusterState() = {
+ assertUsable()
+ var numAlive = 0
+ var numStandby = 0
+ var numLiveApps = 0
+ var liveWorkerIPs: Seq[String] = List()
+
+ def stateValid(): Boolean = {
+ (workers.map(_.ip) -- liveWorkerIPs).isEmpty &&
+ numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1
+ }
+
+ val f = future {
+ try {
+ while (!stateValid()) {
+ Thread.sleep(1000)
+
+ numAlive = 0
+ numStandby = 0
+ numLiveApps = 0
+
+ masters.foreach(_.readState())
+
+ for (master <- masters) {
+ master.state match {
+ case RecoveryState.ALIVE =>
+ numAlive += 1
+ liveWorkerIPs = master.liveWorkerIPs
+ case RecoveryState.STANDBY =>
+ numStandby += 1
+ case _ => // ignore
+ }
+
+ numLiveApps += master.numLiveApps
+ }
+ }
+ true
+ } catch {
+ case e: Exception =>
+ logError("assertValidClusterState() had exception", e)
+ false
+ }
+ }
+
+ try {
+ assertTrue(Await.result(f, 120 seconds))
+ } catch {
+ case e: TimeoutException =>
+ logError("Master states: " + masters.map(_.state))
+ logError("Num apps: " + numLiveApps)
+ logError("IPs expected: " + workers.map(_.ip) + " / found: " + liveWorkerIPs)
+ throw new RuntimeException("Failed to get into acceptable cluster state after 2 min.", e)
+ }
+ }
+
+ def assertTrue(bool: Boolean, message: String = "") {
+ if (!bool) {
+ throw new IllegalStateException("Assertion failed: " + message)
+ }
+ }
+
+ logInfo("Ran %s tests, %s passed and %s failed".format(numPassed+numFailed, numPassed, numFailed))
+}
+
+private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File)
+ extends Logging {
+
+ implicit val formats = net.liftweb.json.DefaultFormats
+ var state: RecoveryState.Value = _
+ var liveWorkerIPs: List[String] = _
+ var numLiveApps = 0
+
+ logDebug("Created master: " + this)
+
+ def readState() {
+ try {
+ val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream)
+ val json = JsonParser.parse(masterStream, closeAutomatically = true)
+
+ val workers = json \ "workers"
+ val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE")
+ liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String])
+
+ numLiveApps = (json \ "activeapps").children.size
+
+ val status = json \\ "status"
+ val stateString = status.extract[String]
+ state = RecoveryState.values.filter(state => state.toString == stateString).head
+ } catch {
+ case e: Exception =>
+ // ignore, no state update
+ logWarning("Exception", e)
+ }
+ }
+
+ def kill() { Docker.kill(dockerId) }
+
+ override def toString: String =
+ "[ip=%s, id=%s, logFile=%s, state=%s]".
+ format(ip, dockerId.id, logFile.getAbsolutePath, state)
+}
+
+private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File)
+ extends Logging {
+
+ implicit val formats = net.liftweb.json.DefaultFormats
+
+ logDebug("Created worker: " + this)
+
+ def kill() { Docker.kill(dockerId) }
+
+ override def toString: String =
+ "[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath)
+}
+
+private[spark] object SparkDocker {
+ def startMaster(mountDir: String): TestMasterInfo = {
+ val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir)
+ val (ip, id, outFile) = startNode(cmd)
+ new TestMasterInfo(ip, id, outFile)
+ }
+
+ def startWorker(mountDir: String, masters: String): TestWorkerInfo = {
+ val cmd = Docker.makeRunCmd("spark-test-worker", args = masters, mountDir = mountDir)
+ val (ip, id, outFile) = startNode(cmd)
+ new TestWorkerInfo(ip, id, outFile)
+ }
+
+ private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = {
+ val ipPromise = promise[String]()
+ val outFile = File.createTempFile("fault-tolerance-test", "")
+ outFile.deleteOnExit()
+ val outStream: FileWriter = new FileWriter(outFile)
+ def findIpAndLog(line: String): Unit = {
+ if (line.startsWith("CONTAINER_IP=")) {
+ val ip = line.split("=")(1)
+ ipPromise.success(ip)
+ }
+
+ outStream.write(line + "\n")
+ outStream.flush()
+ }
+
+ dockerCmd.run(ProcessLogger(findIpAndLog _))
+ val ip = Await.result(ipPromise.future, 30 seconds)
+ val dockerId = Docker.getLastProcessId
+ (ip, dockerId, outFile)
+ }
+}
+
+private[spark] class DockerId(val id: String) {
+ override def toString = id
+}
+
+private[spark] object Docker extends Logging {
+ def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
+ val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
+
+ val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
+ logDebug("Run command: " + cmd)
+ cmd
+ }
+
+ def kill(dockerId: DockerId) : Unit = {
+ "docker kill %s".format(dockerId.id).!
+ }
+
+ def getLastProcessId: DockerId = {
+ var id: String = null
+ "docker ps -l -q".!(ProcessLogger(line => id = line))
+ new DockerId(id)
+ }
+} \ No newline at end of file
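
For orientation, a hedged sketch of the kind of scenario this harness is built to drive (the actual test cases appear earlier in the file and are not part of this hunk); the method name is illustrative and assumes it is defined inside the harness so the helpers above are in scope:

```scala
// Illustrative fail-over scenario using the helpers defined above; assumes it lives
// inside the FaultToleranceTest harness so addMasters, killLeader, etc. are in scope.
def sampleFailoverScenario() {
  addMasters(3)               // one leader plus two standby masters
  addWorkers(2)
  createClient()              // SparkContext pointed at all master URLs
  assertValidClusterState()   // exactly one ALIVE master, workers registered

  killLeader()                // a standby should be elected and recover state
  delay(30.seconds)
  assertValidClusterState()   // cluster usable again after fail-over

  terminateCluster()
}
```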
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index 04d01c169d..e607b8c6f4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -72,7 +72,8 @@ private[spark] object JsonProtocol {
("memory" -> obj.workers.map(_.memory).sum) ~
("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
- ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo))
+ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~
+ ("status" -> obj.status.toString)
}
def writeWorkerState(obj: WorkerStateResponse) = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 6a7d5a85ba..94cf4ff88b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -39,22 +39,23 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
- def start(): String = {
+ def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
+ val masters = Array(masterUrl)
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
- memoryPerWorker, masterUrl, null, Some(workerNum))
+ memoryPerWorker, masters, null, Some(workerNum))
workerActorSystems += workerSystem
}
- return masterUrl
+ return masters
}
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 993ba6bd3d..c29a30184a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,28 +17,59 @@
package org.apache.spark.deploy
-import com.google.common.collect.MapMaker
+import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.SparkException
/**
- * Contains util methods to interact with Hadoop from spark.
+ * Contains util methods to interact with Hadoop from Spark.
*/
+private[spark]
class SparkHadoopUtil {
- // A general, soft-reference map for metadata needed during HadoopRDD split computation
- // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
- private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
+ val conf = newConfiguration()
+ UserGroupInformation.setConfiguration(conf)
- // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
- // subsystems
+ def runAsUser(user: String)(func: () => Unit) {
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
+ }
+
+ /**
+ * Return an appropriate (subclass) of Configuration. Creating a config can initialize some Hadoop
+ * subsystems.
+ */
def newConfiguration(): Configuration = new Configuration()
- // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
- // cluster
+ /**
+ * Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ * cluster.
+ */
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+}
+
+object SparkHadoopUtil {
+ private val hadoop = {
+ val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
+ if (yarnMode) {
+ try {
+ Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
+ } catch {
+ case th: Throwable => throw new SparkException("Unable to load YARN support", th)
+ }
+ } else {
+ new SparkHadoopUtil
+ }
+ }
+ def get: SparkHadoopUtil = {
+ hadoop
+ }
}
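
The companion object added above selects a YARN-aware SparkHadoopUtil at runtime when SPARK_YARN_MODE is set, and callers go through SparkHadoopUtil.get. A minimal script-style usage sketch; the user name is illustrative, and because the class is private[spark] this assumes code living inside the org.apache.spark packages:

```scala
import org.apache.hadoop.mapred.JobConf
import org.apache.spark.deploy.SparkHadoopUtil

val util = SparkHadoopUtil.get            // YarnSparkHadoopUtil when SPARK_YARN_MODE is set
val hadoopConf = util.newConfiguration()  // may initialize Hadoop subsystems

util.runAsUser("spark") { () =>
  // Executed as the given remote user via UserGroupInformation.doAs.
  val jobConf = new JobConf(hadoopConf)
  util.addCredentials(jobConf)            // no-op in the base (non-YARN) implementation
}
```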
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
index 164386782c..be8693ec54 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.concurrent.Await
+import scala.concurrent.ExecutionContext.Implicits.global
import akka.actor._
import akka.actor.Terminated
@@ -37,41 +38,81 @@ import org.apache.spark.deploy.master.Master
/**
* The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
* and a listener for cluster events, and calls back the listener when various events occur.
+ *
+ * @param masterUrls Each url should look like spark://host:port.
*/
private[spark] class Client(
actorSystem: ActorSystem,
- masterUrl: String,
+ masterUrls: Array[String],
appDescription: ApplicationDescription,
listener: ClientListener)
extends Logging {
+ val REGISTRATION_TIMEOUT = 20.seconds
+ val REGISTRATION_RETRIES = 3
+
var actor: ActorRef = null
var appId: String = null
+ var registered = false
+ var activeMasterUrl: String = null
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
+ var alreadyDead = false // To avoid calling listener.dead() multiple times
override def preStart() {
- logInfo("Connecting to master " + masterUrl)
try {
- master = context.actorFor(Master.toAkkaUrl(masterUrl))
- masterAddress = master.path.address
- master ! RegisterApplication(appDescription)
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ registerWithMaster()
} catch {
case e: Exception =>
- logError("Failed to connect to master", e)
+ logWarning("Failed to connect to master", e)
markDisconnected()
context.stop(self)
}
}
+ def tryRegisterAllMasters() {
+ for (masterUrl <- masterUrls) {
+ logInfo("Connecting to master " + masterUrl + "...")
+ val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ actor ! RegisterApplication(appDescription)
+ }
+ }
+
+ def registerWithMaster() {
+ tryRegisterAllMasters()
+
+ var retries = 0
+ lazy val retryTimer: Cancellable =
+ context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ retries += 1
+ if (registered) {
+ retryTimer.cancel()
+ } else if (retries >= REGISTRATION_RETRIES) {
+ logError("All masters are unresponsive! Giving up.")
+ markDead()
+ } else {
+ tryRegisterAllMasters()
+ }
+ }
+ retryTimer // start timer
+ }
+
+ def changeMaster(url: String) {
+ activeMasterUrl = url
+ master = context.actorFor(Master.toAkkaUrl(url))
+ masterAddress = master.path.address
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ }
+
override def receive = {
- case RegisteredApplication(appId_) =>
+ case RegisteredApplication(appId_, masterUrl) =>
appId = appId_
+ registered = true
+ changeMaster(masterUrl)
listener.connected(appId)
case ApplicationRemoved(message) =>
@@ -92,23 +133,27 @@ private[spark] class Client(
listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
}
+ case MasterChanged(masterUrl, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterUrl)
+ context.unwatch(master)
+ changeMaster(masterUrl)
+ alreadyDisconnected = false
+ sender ! MasterChangeAcknowledged(appId)
+
case Terminated(actor_) if actor_ == master =>
- logError("Connection to master failed; stopping client")
+ logWarning("Connection to master failed; waiting for master to reconnect...")
markDisconnected()
- context.stop(self)
case DisassociatedEvent(_, address, _) if address == masterAddress =>
logError("Connection to master failed; stopping client")
markDisconnected()
- context.stop(self)
case AssociationErrorEvent(_, _, address, _) if address == masterAddress =>
logError("Connection to master failed; stopping client")
markDisconnected()
- context.stop(self)
case StopClient =>
- markDisconnected()
+ markDead()
sender ! true
context.stop(self)
}
@@ -122,6 +167,13 @@ private[spark] class Client(
alreadyDisconnected = true
}
}
+
+ def markDead() {
+ if (!alreadyDead) {
+ listener.dead()
+ alreadyDead = true
+ }
+ }
}
def start() {
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
index 4605368c11..be7a11bd15 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
@@ -27,8 +27,12 @@ package org.apache.spark.deploy.client
private[spark] trait ClientListener {
def connected(appId: String): Unit
+ /** Disconnection may be a temporary state, as we fail over to a new Master. */
def disconnected(): Unit
+ /** Dead means that we couldn't find any Masters to connect to, and have given up. */
+ def dead(): Unit
+
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
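
With the new dead() callback, listeners can distinguish a temporary loss of the master (fail-over in progress) from the terminal case where no master could be reached. A minimal implementation sketch, assuming it lives inside the Spark packages since the trait is private[spark]; the class name and println logging are illustrative:

```scala
import org.apache.spark.deploy.client.ClientListener

class LoggingClientListener extends ClientListener {
  def connected(appId: String) { println("Connected as application " + appId) }
  def disconnected() { println("Master lost; waiting for fail-over...") }  // may recover
  def dead() { println("No masters responded; giving up.") }               // terminal
  def executorAdded(fullId: String, workerId: String, hostPort: String,
      cores: Int, memory: Int) {}
  def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) {}
}
```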
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index d5e9a0e095..5b62d3ba6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -33,6 +33,11 @@ private[spark] object TestClient {
System.exit(0)
}
+ def dead() {
+ logInfo("Could not connect to master")
+ System.exit(0)
+ }
+
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
@@ -44,7 +49,7 @@ private[spark] object TestClient {
val desc = new ApplicationDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
val listener = new TestListener
- val client = new Client(actorSystem, url, desc, listener)
+ val client = new Client(actorSystem, Array(url), desc, listener)
client.start()
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index bd5327627a..5150b7c7de 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -29,23 +29,46 @@ private[spark] class ApplicationInfo(
val submitDate: Date,
val driver: ActorRef,
val appUiUrl: String)
-{
- var state = ApplicationState.WAITING
- var executors = new mutable.HashMap[Int, ExecutorInfo]
- var coresGranted = 0
- var endTime = -1L
- val appSource = new ApplicationSource(this)
-
- private var nextExecutorId = 0
-
- def newExecutorId(): Int = {
- val id = nextExecutorId
- nextExecutorId += 1
- id
+ extends Serializable {
+
+ @transient var state: ApplicationState.Value = _
+ @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
+ @transient var coresGranted: Int = _
+ @transient var endTime: Long = _
+ @transient var appSource: ApplicationSource = _
+
+ @transient private var nextExecutorId: Int = _
+
+ init()
+
+ private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ in.defaultReadObject()
+ init()
+ }
+
+ private def init() {
+ state = ApplicationState.WAITING
+ executors = new mutable.HashMap[Int, ExecutorInfo]
+ coresGranted = 0
+ endTime = -1L
+ appSource = new ApplicationSource(this)
+ nextExecutorId = 0
+ }
+
+ private def newExecutorId(useID: Option[Int] = None): Int = {
+ useID match {
+ case Some(id) =>
+ nextExecutorId = math.max(nextExecutorId, id + 1)
+ id
+ case None =>
+ val id = nextExecutorId
+ nextExecutorId += 1
+ id
+ }
}
- def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = {
- val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave)
+ def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
+ val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
executors(exec.id) = exec
coresGranted += cores
exec
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index 39ef090ddf..a74d7be4c9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -22,7 +22,7 @@ private[spark] object ApplicationState
type ApplicationState = Value
- val WAITING, RUNNING, FINISHED, FAILED = Value
+ val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value
val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
index cf384a985e..76db61dd61 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.master
-import org.apache.spark.deploy.ExecutorState
+import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
private[spark] class ExecutorInfo(
val id: Int,
@@ -28,5 +28,10 @@ private[spark] class ExecutorInfo(
var state = ExecutorState.LAUNCHING
+ /** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
+ def copyState(execDesc: ExecutorDescription) {
+ state = execDesc.state
+ }
+
def fullId: String = application.id + "/" + id
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
new file mode 100644
index 0000000000..043945a211
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import java.io._
+
+import scala.Serializable
+
+import akka.serialization.Serialization
+import org.apache.spark.Logging
+
+/**
+ * Stores data in a single on-disk directory with one file per application and worker.
+ * Files are deleted when applications and workers are removed.
+ *
+ * @param dir Directory to store files. Created if non-existent (but not recursively).
+ * @param serialization Used to serialize our objects.
+ */
+private[spark] class FileSystemPersistenceEngine(
+ val dir: String,
+ val serialization: Serialization)
+ extends PersistenceEngine with Logging {
+
+ new File(dir).mkdir()
+
+ override def addApplication(app: ApplicationInfo) {
+ val appFile = new File(dir + File.separator + "app_" + app.id)
+ serializeIntoFile(appFile, app)
+ }
+
+ override def removeApplication(app: ApplicationInfo) {
+ new File(dir + File.separator + "app_" + app.id).delete()
+ }
+
+ override def addWorker(worker: WorkerInfo) {
+ val workerFile = new File(dir + File.separator + "worker_" + worker.id)
+ serializeIntoFile(workerFile, worker)
+ }
+
+ override def removeWorker(worker: WorkerInfo) {
+ new File(dir + File.separator + "worker_" + worker.id).delete()
+ }
+
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
+ val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ (apps, workers)
+ }
+
+ private def serializeIntoFile(file: File, value: AnyRef) {
+ val created = file.createNewFile()
+ if (!created) { throw new IllegalStateException("Could not create file: " + file) }
+
+ val serializer = serialization.findSerializerFor(value)
+ val serialized = serializer.toBinary(value)
+
+ val out = new FileOutputStream(file)
+ out.write(serialized)
+ out.close()
+ }
+
+ def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
+ val fileData = new Array[Byte](file.length().asInstanceOf[Int])
+ val dis = new DataInputStream(new FileInputStream(file))
+ dis.readFully(fileData)
+ dis.close()
+
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
+ val serializer = serialization.serializerFor(clazz)
+ serializer.fromBinary(fileData).asInstanceOf[T]
+ }
+}
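
A hedged, script-style sketch of constructing this engine the way the Master does for the FILESYSTEM recovery mode (see the Master changes later in this patch); the actor-system name and recovery directory are illustrative, and the class is private[spark], so this assumes Spark-internal code:

```scala
import akka.actor.ActorSystem
import akka.serialization.SerializationExtension
import org.apache.spark.deploy.master.FileSystemPersistenceEngine

val system = ActorSystem("sparkMaster")
val engine = new FileSystemPersistenceEngine("/var/lib/spark/recovery",
  SerializationExtension(system))

// Registration paths call addWorker/addApplication; a newly elected leader replays
// whatever survived on disk:
val (storedApps, storedWorkers) = engine.readPersistedData()
```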
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
new file mode 100644
index 0000000000..f25a1ad3bf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.actor.{Actor, ActorRef}
+
+import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
+
+/**
+ * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
+ * is the only Master serving requests.
+ * In addition to the API provided, the LeaderElectionAgent will use the following messages
+ * to inform the Master of leader changes:
+ * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
+ * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
+ */
+private[spark] trait LeaderElectionAgent extends Actor {
+ val masterActor: ActorRef
+}
+
+/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
+private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
+ override def preStart() {
+ masterActor ! ElectedLeader
+ }
+
+ override def receive = {
+ case _ =>
+ }
+}
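
MonarchyLeaderAgent covers the single-master case by granting leadership unconditionally. As a purely illustrative sketch (not part of this patch), an agent could also exercise the revocation path; this assumes it sits in the org.apache.spark.deploy.master package:

```scala
import akka.actor.ActorRef

import org.apache.spark.deploy.master.MasterMessages.{ElectedLeader, RevokedLeadership}

private[spark] class ToggleLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
  // Grant leadership immediately, like MonarchyLeaderAgent...
  override def preStart() { masterActor ! ElectedLeader }

  // ...but allow a test to revoke it and drive the Master's RevokedLeadership path.
  override def receive = {
    case "revoke" => masterActor ! RevokedLeadership
    case _ =>
  }
}
```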
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index cb0fe6a850..26f980760d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -23,23 +23,25 @@ import java.text.SimpleDateFormat
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.concurrent.Await
import scala.concurrent.duration._
+import scala.concurrent.duration.{ Duration, FiniteDuration }
+import scala.concurrent.ExecutionContext.Implicits.global
import akka.actor._
import akka.pattern.ask
import akka.remote._
+import akka.util.Timeout
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
-import akka.util.Timeout
import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
import org.apache.spark.deploy.DeployMessages.KillExecutor
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
import scala.Some
-import org.apache.spark.deploy.DeployMessages.WebUIPortResponse
import org.apache.spark.deploy.DeployMessages.LaunchExecutor
import org.apache.spark.deploy.DeployMessages.RegisteredApplication
import org.apache.spark.deploy.DeployMessages.RegisterWorker
@@ -51,6 +53,8 @@ import org.apache.spark.deploy.DeployMessages.ApplicationRemoved
import org.apache.spark.deploy.DeployMessages.Heartbeat
import org.apache.spark.deploy.DeployMessages.RegisteredWorker
import akka.actor.Terminated
+import akka.serialization.SerializationExtension
+import java.util.concurrent.TimeUnit
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
@@ -58,7 +62,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt
-
+ val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "")
+ val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE")
+
var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
val idToWorker = new HashMap[String, WorkerInfo]
@@ -88,52 +94,115 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (envVar != null) envVar else host
}
+ val masterUrl = "spark://" + host + ":" + port
+ var masterWebUiUrl: String = _
+
+ var state = RecoveryState.STANDBY
+
+ var persistenceEngine: PersistenceEngine = _
+
+ var leaderElectionAgent: ActorRef = _
+
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
// among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
- logInfo("Starting Spark master at spark://" + host + ":" + port)
+ logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
- import context.dispatcher
+ masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
applicationMetricsSystem.start()
+
+ persistenceEngine = RECOVERY_MODE match {
+ case "ZOOKEEPER" =>
+ logInfo("Persisting recovery state to ZooKeeper")
+ new ZooKeeperPersistenceEngine(SerializationExtension(context.system))
+ case "FILESYSTEM" =>
+ logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
+ new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
+ case _ =>
+ new BlackHolePersistenceEngine()
+ }
+
+ leaderElectionAgent = RECOVERY_MODE match {
+ case "ZOOKEEPER" =>
+ context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl))
+ case _ =>
+ context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
+ }
+ }
+
+ override def preRestart(reason: Throwable, message: Option[Any]) {
+ super.preRestart(reason, message) // calls postStop()!
+ logError("Master actor restarted due to exception", reason)
}
override def postStop() {
webUi.stop()
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
+ persistenceEngine.close()
+ context.stop(leaderElectionAgent)
}
override def receive = {
- case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => {
+ case ElectedLeader => {
+ val (storedApps, storedWorkers) = persistenceEngine.readPersistedData()
+ state = if (storedApps.isEmpty && storedWorkers.isEmpty)
+ RecoveryState.ALIVE
+ else
+ RecoveryState.RECOVERING
+
+ logInfo("I have been elected leader! New state: " + state)
+
+ if (state == RecoveryState.RECOVERING) {
+ beginRecovery(storedApps, storedWorkers)
+ context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
+ }
+ }
+
+ case RevokedLeadership => {
+ logError("Leadership has been revoked -- master shutting down.")
+ System.exit(0)
+ }
+
+ case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
host, workerPort, cores, Utils.megabytesToString(memory)))
- if (idToWorker.contains(id)) {
+ if (state == RecoveryState.STANDBY) {
+ // ignore, don't send response
+ } else if (idToWorker.contains(id)) {
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
- addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress)
+ val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+ registerWorker(worker)
context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get)
+ persistenceEngine.addWorker(worker)
+ sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
schedule()
}
}
case RegisterApplication(description) => {
- logInfo("Registering app " + description.name)
- val app = addApplication(description, sender)
- logInfo("Registered app " + description.name + " with ID " + app.id)
- waitingApps += app
- context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredApplication(app.id)
- schedule()
+ if (state == RecoveryState.STANDBY) {
+ // ignore, don't send response
+ } else {
+ logInfo("Registering app " + description.name)
+ val app = createApplication(description, sender)
+ registerApplication(app)
+ logInfo("Registered app " + description.name + " with ID " + app.id)
+ context.watch(sender) // This doesn't work with remote actors but helps for testing
+ persistenceEngine.addApplication(app)
+ sender ! RegisteredApplication(app.id, masterUrl)
+ schedule()
+ }
}
case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
@@ -173,27 +242,63 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ case MasterChangeAcknowledged(appId) => {
+ idToApp.get(appId) match {
+ case Some(app) =>
+ logInfo("Application has been re-registered: " + appId)
+ app.state = ApplicationState.WAITING
+ case None =>
+ logWarning("Master change ack from unknown app: " + appId)
+ }
+
+ if (canCompleteRecovery) { completeRecovery() }
+ }
+
+ case WorkerSchedulerStateResponse(workerId, executors) => {
+ idToWorker.get(workerId) match {
+ case Some(worker) =>
+ logInfo("Worker has been re-registered: " + workerId)
+ worker.state = WorkerState.ALIVE
+
+ val validExecutors = executors.filter(exec => idToApp.get(exec.appId).isDefined)
+ for (exec <- validExecutors) {
+ val app = idToApp.get(exec.appId).get
+ val execInfo = app.addExecutor(worker, exec.cores, Some(exec.execId))
+ worker.addExecutor(execInfo)
+ execInfo.copyState(exec)
+ }
+ case None =>
+ logWarning("Scheduler state from unknown worker: " + workerId)
+ }
+
+ if (canCompleteRecovery) { completeRecovery() }
+ }
+
case Terminated(actor) => {
// The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
actorToApp.get(actor).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case DisassociatedEvent(_, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case AssociationErrorEvent(_, _, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case RequestMasterState => {
- sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray)
+ sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
+ state)
}
case CheckForWorkerTimeOut => {
@@ -205,6 +310,50 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ def canCompleteRecovery =
+ workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
+ apps.count(_.state == ApplicationState.UNKNOWN) == 0
+
+ def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) {
+ for (app <- storedApps) {
+ logInfo("Trying to recover app: " + app.id)
+ try {
+ registerApplication(app)
+ app.state = ApplicationState.UNKNOWN
+ app.driver ! MasterChanged(masterUrl, masterWebUiUrl)
+ } catch {
+ case e: Exception => logInfo("App " + app.id + " had exception on reconnect")
+ }
+ }
+
+ for (worker <- storedWorkers) {
+ logInfo("Trying to recover worker: " + worker.id)
+ try {
+ registerWorker(worker)
+ worker.state = WorkerState.UNKNOWN
+ worker.actor ! MasterChanged(masterUrl, masterWebUiUrl)
+ } catch {
+ case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect")
+ }
+ }
+ }
+
+ def completeRecovery() {
+ // Ensure "only-once" recovery semantics using a short synchronization period.
+ synchronized {
+ if (state != RecoveryState.RECOVERING) { return }
+ state = RecoveryState.COMPLETING_RECOVERY
+ }
+
+ // Kill off any workers and apps that didn't respond to us.
+ workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
+ apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
+
+ state = RecoveryState.ALIVE
+ schedule()
+ logInfo("Recovery complete - resuming operations!")
+ }
+
/**
* Can an app use the given worker? True if the worker has enough memory and we haven't already
* launched an executor for the app on it (right now the standalone backend doesn't like having
@@ -219,6 +368,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
* every time a new app joins or resource availability changes.
*/
def schedule() {
+ if (state != RecoveryState.ALIVE) { return }
// Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
// in the queue, then the second app, etc.
if (spreadOutApps) {
@@ -266,14 +416,13 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(
+ worker.actor ! LaunchExecutor(masterUrl,
exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
exec.application.driver ! ExecutorAdded(
exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
- def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
- publicAddress: String): WorkerInfo = {
+ def registerWorker(worker: WorkerInfo): Unit = {
// There may be one or more refs to dead workers on this same node (w/ different ID's),
// remove them.
workers.filter { w =>
@@ -281,12 +430,17 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}.foreach { w =>
workers -= w
}
- val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+
+ val workerAddress = worker.actor.path.address
+ if (addressToWorker.contains(workerAddress)) {
+ logInfo("Attempted to re-register worker at same address: " + workerAddress)
+ return
+ }
+
workers += worker
idToWorker(worker.id) = worker
- actorToWorker(sender) = worker
- addressToWorker(sender.path.address) = worker
- worker
+ actorToWorker(worker.actor) = worker
+ addressToWorker(workerAddress) = worker
}
def removeWorker(worker: WorkerInfo) {
@@ -301,25 +455,36 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
exec.id, ExecutorState.LOST, Some("worker lost"), None)
exec.application.removeExecutor(exec)
}
+ persistenceEngine.removeWorker(worker)
}
- def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
+ def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ }
+
+ def registerApplication(app: ApplicationInfo): Unit = {
+ val appAddress = app.driver.path.address
+ if (addressToWorker.contains(appAddress)) {
+ logInfo("Attempted to re-register application at same address: " + appAddress)
+ return
+ }
+
applicationMetricsSystem.registerSource(app.appSource)
apps += app
idToApp(app.id) = app
- actorToApp(driver) = app
- addressToApp(driver.path.address) = app
+ actorToApp(app.driver) = app
+ addressToApp(appAddress) = app
if (firstApp == None) {
firstApp = Some(app)
}
+ // TODO: What is firstApp?? Can we remove it?
val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
- if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) {
+ if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) {
logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
}
- app
+ waitingApps += app
}
def finishApplication(app: ApplicationInfo) {
@@ -344,13 +509,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
waitingApps -= app
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
- exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
+ exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
exec.state = ExecutorState.KILLED
}
app.markFinished(state)
if (state != ApplicationState.FINISHED) {
app.driver ! ApplicationRemoved(state.toString)
}
+ persistenceEngine.removeApplication(app)
schedule()
}
}
@@ -404,8 +570,8 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName)
- val timeoutDuration = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeoutDuration : FiniteDuration = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS)
implicit val timeout = Timeout(timeoutDuration)
val respFuture = actor ? RequestWebUIPort // ask pattern
val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse]
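
All of the recovery behaviour above is selected through system properties read in preStart(). A hedged configuration sketch using the property names from this patch; the values are illustrative, and in practice these would normally be passed as -D JVM options when launching the Master rather than set programmatically:

```scala
// File-system recovery: a single Master persists state to local disk and can
// recover it after being restarted.
System.setProperty("spark.deploy.recoveryMode", "FILESYSTEM")
System.setProperty("spark.deploy.recoveryDirectory", "/var/lib/spark/recovery")

// ZooKeeper-based recovery: several Masters register with ZooKeeper; a standby
// takes over when the elected leader dies.
System.setProperty("spark.deploy.recoveryMode", "ZOOKEEPER")
System.setProperty("spark.deploy.zookeeper.url", "zk1:2181,zk2:2181,zk3:2181")
System.setProperty("spark.deploy.zookeeper.dir", "/spark")   // "/spark" is also the default
```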
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
new file mode 100644
index 0000000000..74a9f8cd82
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+sealed trait MasterMessages extends Serializable
+
+/** Contains messages seen only by the Master and its associated entities. */
+private[master] object MasterMessages {
+
+ // LeaderElectionAgent to Master
+
+ case object ElectedLeader
+
+ case object RevokedLeadership
+
+ // Actor System to LeaderElectionAgent
+
+ case object CheckLeader
+
+ // Actor System to Master
+
+ case object CheckForWorkerTimeOut
+
+ case class BeginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo])
+
+ case object CompleteRecovery
+
+ case object RequestWebUIPort
+
+ case class WebUIPortResponse(webUIBoundPort: Int)
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
new file mode 100644
index 0000000000..94b986caf2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+/**
+ * Allows Master to persist any state that is necessary in order to recover from a failure.
+ * The following semantics are required:
+ * - addApplication and addWorker are called before completing registration of a new app/worker.
+ * - removeApplication and removeWorker may be called at any time.
+ * Given these two requirements, we will have all apps and workers persisted, but
+ * we might not have yet deleted apps or workers that finished (so their liveness must be verified
+ * during recovery).
+ */
+private[spark] trait PersistenceEngine {
+ def addApplication(app: ApplicationInfo)
+
+ def removeApplication(app: ApplicationInfo)
+
+ def addWorker(worker: WorkerInfo)
+
+ def removeWorker(worker: WorkerInfo)
+
+ /**
+ * Returns the persisted data sorted by their respective ids (which implies that they're
+ * sorted by time of creation).
+ */
+ def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo])
+
+ def close() {}
+}
+
+private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
+ override def addApplication(app: ApplicationInfo) {}
+ override def removeApplication(app: ApplicationInfo) {}
+ override def addWorker(worker: WorkerInfo) {}
+ override def removeWorker(worker: WorkerInfo) {}
+ override def readPersistedData() = (Nil, Nil)
+}
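
BlackHolePersistenceEngine is the no-recovery default. As a hedged illustration of the trait's contract (not part of this patch), an in-memory engine could look like the sketch below; it is only useful for tests since nothing survives a process restart, and it assumes Spark-internal code because ApplicationInfo and WorkerInfo are private[spark]:

```scala
import scala.collection.mutable

private[spark] class InMemoryPersistenceEngine extends PersistenceEngine {
  private val apps = new mutable.LinkedHashMap[String, ApplicationInfo]
  private val workers = new mutable.LinkedHashMap[String, WorkerInfo]

  override def addApplication(app: ApplicationInfo) { apps(app.id) = app }
  override def removeApplication(app: ApplicationInfo) { apps -= app.id }
  override def addWorker(worker: WorkerInfo) { workers(worker.id) = worker }
  override def removeWorker(worker: WorkerInfo) { workers -= worker.id }

  // Insertion order stands in for creation order, matching readPersistedData()'s doc.
  override def readPersistedData() = (apps.values.toSeq, workers.values.toSeq)
}
```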
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
new file mode 100644
index 0000000000..b91be821f0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+private[spark] object RecoveryState
+ extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") {
+
+ type MasterState = Value
+
+ val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
new file mode 100644
index 0000000000..81e15c534f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import scala.collection.JavaConversions._
+import scala.concurrent.ops._
+
+import org.apache.spark.Logging
+import org.apache.zookeeper._
+import org.apache.zookeeper.data.Stat
+import org.apache.zookeeper.Watcher.Event.KeeperState
+
+/**
+ * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
+ * logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be
+ * created. If ZooKeeper remains down after several retries, the given
+ * [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be
+ * informed via zkDown().
+ *
+ * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
+ * times or a semantic exception is thrown (e.g., "node already exists").
+ */
+private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging {
+ val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "")
+
+ val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
+ val ZK_TIMEOUT_MILLIS = 30000
+ val RETRY_WAIT_MILLIS = 5000
+ val ZK_CHECK_PERIOD_MILLIS = 10000
+ val MAX_RECONNECT_ATTEMPTS = 3
+
+ private var zk: ZooKeeper = _
+
+ private val watcher = new ZooKeeperWatcher()
+ private var reconnectAttempts = 0
+ private var closed = false
+
+ /** Connect to ZooKeeper to start the session. Must be called before anything else. */
+ def connect() {
+ connectToZooKeeper()
+
+ new Thread() {
+ override def run() = sessionMonitorThread()
+ }.start()
+ }
+
+ def sessionMonitorThread(): Unit = {
+ while (!closed) {
+ Thread.sleep(ZK_CHECK_PERIOD_MILLIS)
+ if (zk.getState != ZooKeeper.States.CONNECTED) {
+ reconnectAttempts += 1
+ val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts
+ if (attemptsLeft <= 0) {
+ logError("Could not connect to ZooKeeper: system failure")
+ zkWatcher.zkDown()
+ close()
+ } else {
+ logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...")
+ connectToZooKeeper()
+ }
+ }
+ }
+ }
+
+ def close() {
+ if (!closed && zk != null) { zk.close() }
+ closed = true
+ }
+
+ private def connectToZooKeeper() {
+ if (zk != null) zk.close()
+ zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher)
+ }
+
+ /**
+ * Attempts to maintain a live ZooKeeper connection despite (very) transient failures.
+ * Mainly useful for handling the natural ZooKeeper session expiration.
+ */
+ private class ZooKeeperWatcher extends Watcher {
+ def process(event: WatchedEvent) {
+ if (closed) { return }
+
+ event.getState match {
+ case KeeperState.SyncConnected =>
+ reconnectAttempts = 0
+ zkWatcher.zkSessionCreated()
+ case KeeperState.Expired =>
+ connectToZooKeeper()
+ case KeeperState.Disconnected =>
+ logWarning("ZooKeeper disconnected, will retry...")
+ }
+ }
+ }
+
+ def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = {
+ retry {
+ zk.create(path, bytes, ZK_ACL, createMode)
+ }
+ }
+
+ def exists(path: String, watcher: Watcher = null): Stat = {
+ retry {
+ zk.exists(path, watcher)
+ }
+ }
+
+ def getChildren(path: String, watcher: Watcher = null): List[String] = {
+ retry {
+ zk.getChildren(path, watcher).toList
+ }
+ }
+
+ def getData(path: String): Array[Byte] = {
+ retry {
+ zk.getData(path, false, null)
+ }
+ }
+
+ def delete(path: String, version: Int = -1): Unit = {
+ retry {
+ zk.delete(path, version)
+ }
+ }
+
+ /**
+ * Creates the given directory (non-recursively) if it doesn't exist.
+ * All znodes are created in PERSISTENT mode with no data.
+ */
+ def mkdir(path: String) {
+ if (exists(path) == null) {
+ try {
+ create(path, "".getBytes, CreateMode.PERSISTENT)
+ } catch {
+ case e: Exception =>
+ // If the exception caused the directory not to be created, bubble it up,
+ // otherwise ignore it.
+ if (exists(path) == null) { throw e }
+ }
+ }
+ }
+
+ /**
+ * Recursively creates all directories up to the given one.
+ * All znodes are created in PERSISTENT mode with no data.
+ */
+ def mkdirRecursive(path: String) {
+ var fullDir = ""
+ for (dentry <- path.split("/").tail) {
+ fullDir += "/" + dentry
+ mkdir(fullDir)
+ }
+ }
+
+ /**
+ * Retries the given function up to 3 times. The assumption is that failure is transient,
+ * UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist),
+ * in which case the exception will be thrown without retries.
+ *
+ * @param fn Block to execute, possibly multiple times.
+ */
+ def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = {
+ try {
+ fn
+ } catch {
+ case e: KeeperException.NoNodeException => throw e
+ case e: KeeperException.NodeExistsException => throw e
+ case e if n > 0 =>
+ logError("ZooKeeper exception, " + n + " more retries...", e)
+ Thread.sleep(RETRY_WAIT_MILLIS)
+ retry(fn, n-1)
+ }
+ }
+}
+
+trait SparkZooKeeperWatcher {
+ /**
+ * Called whenever a ZK session is created --
+ * this will occur when we create our first session as well as each time
+ * the session expires or errors out.
+ */
+ def zkSessionCreated()
+
+ /**
+ * Called if ZK appears to be completely down (i.e., not just a transient error).
+ * We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead.
+ */
+ def zkDown()
+}
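
A hedged usage sketch of the session wrapper and watcher above: it connects (which requires spark.deploy.zookeeper.url to be set), creates a working directory and an ephemeral sequential node, then lists the children. The object name, paths, and println logging are illustrative, and since SparkZooKeeperSession is private[spark] this assumes code under org.apache.spark.deploy.master:

```scala
import org.apache.zookeeper.CreateMode

object ZkSessionExample extends SparkZooKeeperWatcher {
  private val zk = new SparkZooKeeperSession(this)

  override def zkSessionCreated() { println("ZooKeeper session (re)created") }
  override def zkDown() { println("ZooKeeper considered dead; giving up") }

  def main(args: Array[String]) {
    zk.connect()                               // needs spark.deploy.zookeeper.url set
    zk.mkdirRecursive("/spark/example")
    zk.create("/spark/example/node_", "hello".getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
    println(zk.getChildren("/spark/example"))
    zk.close()
  }
}
```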
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 6219f11f2a..e05f587b58 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -22,28 +22,44 @@ import scala.collection.mutable
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
- val id: String,
- val host: String,
- val port: Int,
- val cores: Int,
- val memory: Int,
- val actor: ActorRef,
- val webUiPort: Int,
- val publicAddress: String) {
+ val id: String,
+ val host: String,
+ val port: Int,
+ val cores: Int,
+ val memory: Int,
+ val actor: ActorRef,
+ val webUiPort: Int,
+ val publicAddress: String)
+ extends Serializable {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
- var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
- var state: WorkerState.Value = WorkerState.ALIVE
- var coresUsed = 0
- var memoryUsed = 0
+ @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info
+ @transient var state: WorkerState.Value = _
+ @transient var coresUsed: Int = _
+ @transient var memoryUsed: Int = _
- var lastHeartbeat = System.currentTimeMillis()
+ @transient var lastHeartbeat: Long = _
+
+ init()
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
+ private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ in.defaultReadObject()
+ init()
+ }
+
+ private def init() {
+ executors = new mutable.HashMap
+ state = WorkerState.ALIVE
+ coresUsed = 0
+ memoryUsed = 0
+ lastHeartbeat = System.currentTimeMillis()
+ }
+
def hostPort: String = {
assert (port > 0)
host + ":" + port
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
index fb3fe88d92..0b36ef6005 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
@@ -20,5 +20,5 @@ package org.apache.spark.deploy.master
private[spark] object WorkerState extends Enumeration {
type WorkerState = Value
- val ALIVE, DEAD, DECOMMISSIONED = Value
+ val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
new file mode 100644
index 0000000000..7809013e83
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.actor.ActorRef
+import org.apache.zookeeper._
+import org.apache.zookeeper.Watcher.Event.EventType
+
+import org.apache.spark.deploy.master.MasterMessages._
+import org.apache.spark.Logging
+
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String)
+ extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
+
+ val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
+
+ private val watcher = new ZooKeeperWatcher()
+ private val zk = new SparkZooKeeperSession(this)
+ private var status = LeadershipStatus.NOT_LEADER
+ private var myLeaderFile: String = _
+ private var leaderUrl: String = _
+
+ override def preStart() {
+ logInfo("Starting ZooKeeper LeaderElection agent")
+ zk.connect()
+ }
+
+ override def zkSessionCreated() {
+ synchronized {
+ zk.mkdirRecursive(WORKING_DIR)
+ myLeaderFile =
+ zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
+ self ! CheckLeader
+ }
+ }
+
+ override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
+ logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason)
+ Thread.sleep(zk.ZK_TIMEOUT_MILLIS)
+ super.preRestart(reason, message)
+ }
+
+ override def zkDown() {
+ logError("ZooKeeper down! LeaderElectionAgent shutting down Master.")
+ System.exit(1)
+ }
+
+ override def postStop() {
+ zk.close()
+ }
+
+ override def receive = {
+ case CheckLeader => checkLeader()
+ }
+
+ private class ZooKeeperWatcher extends Watcher {
+ def process(event: WatchedEvent) {
+ if (event.getType == EventType.NodeDeleted) {
+ logInfo("Leader file disappeared, a master is down!")
+ self ! CheckLeader
+ }
+ }
+ }
+
+ /** Uses ZK leader election. Navigates several ZK potholes along the way. */
+ def checkLeader() {
+ val masters = zk.getChildren(WORKING_DIR).toList
+ val leader = masters.sorted.head
+ val leaderFile = WORKING_DIR + "/" + leader
+
+ // Setup a watch for the current leader.
+ zk.exists(leaderFile, watcher)
+
+ try {
+ leaderUrl = new String(zk.getData(leaderFile))
+ } catch {
+ // A NoNodeException may be thrown if the old leader died since the start of this method call.
+ // This is fine -- just check again, since we're guaranteed to see the new values.
+ case e: KeeperException.NoNodeException =>
+ logInfo("Leader disappeared while reading it -- finding next leader")
+ checkLeader()
+ return
+ }
+
+ // Synchronization used to ensure no interleaving between the creation of a new session and the
+ // checking of a leader, which could cause us to delete our real leader file erroneously.
+ synchronized {
+ val isLeader = myLeaderFile == leaderFile
+ if (!isLeader && leaderUrl == masterUrl) {
+ // We found a different master file pointing to this process.
+ // This can happen in the following two cases:
+ // (1) The master process was restarted on the same node.
+ // (2) The ZK server died between creating the node and returning the name of the node.
+ // For this case, we will end up creating a second file, and MUST explicitly delete the
+ // first one, since our ZK session is still open.
+ // Note that this deletion will cause a NodeDeleted event to be fired so we check again for
+ // leader changes.
+ assert(leaderFile < myLeaderFile)
+ logWarning("Cleaning up old ZK master election file that points to this master.")
+ zk.delete(leaderFile)
+ } else {
+ updateLeadershipStatus(isLeader)
+ }
+ }
+ }
+
+ def updateLeadershipStatus(isLeader: Boolean) {
+ if (isLeader && status == LeadershipStatus.NOT_LEADER) {
+ status = LeadershipStatus.LEADER
+ masterActor ! ElectedLeader
+ } else if (!isLeader && status == LeadershipStatus.LEADER) {
+ status = LeadershipStatus.NOT_LEADER
+ masterActor ! RevokedLeadership
+ }
+ }
+
+ private object LeadershipStatus extends Enumeration {
+ type LeadershipStatus = Value
+ val LEADER, NOT_LEADER = Value
+ }
+}
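checkLeader above treats the lexicographically smallest child of the election directory as the leader, which works because EPHEMERAL_SEQUENTIAL node names carry a monotonically increasing sequence number. A self-contained sketch of just that selection rule, with no ZooKeeper client involved (the node names are made up):

    object LeaderSelectionDemo extends App {
      // Child names as ZooKeeper would return them for EPHEMERAL_SEQUENTIAL creates.
      val children = Seq("master_0000000002", "master_0000000000", "master_0000000001")
      val myNode = "master_0000000001"

      // Lowest sequence number wins, exactly as masters.sorted.head does above.
      val leader = children.sorted.head
      val isLeader = leader == myNode

      println("leader=" + leader + ", thisNodeIsLeader=" + isLeader)
      // prints: leader=master_0000000000, thisNodeIsLeader=false
    }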
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
new file mode 100644
index 0000000000..825344b3bb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.Logging
+import org.apache.zookeeper._
+
+import akka.serialization.Serialization
+
+class ZooKeeperPersistenceEngine(serialization: Serialization)
+ extends PersistenceEngine
+ with SparkZooKeeperWatcher
+ with Logging
+{
+ val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
+
+ val zk = new SparkZooKeeperSession(this)
+
+ zk.connect()
+
+ override def zkSessionCreated() {
+ zk.mkdirRecursive(WORKING_DIR)
+ }
+
+ override def zkDown() {
+ logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.")
+ }
+
+ override def addApplication(app: ApplicationInfo) {
+ serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
+ }
+
+ override def removeApplication(app: ApplicationInfo) {
+ zk.delete(WORKING_DIR + "/app_" + app.id)
+ }
+
+ override def addWorker(worker: WorkerInfo) {
+ serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
+ }
+
+ override def removeWorker(worker: WorkerInfo) {
+ zk.delete(WORKING_DIR + "/worker_" + worker.id)
+ }
+
+ override def close() {
+ zk.close()
+ }
+
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
+ val appFiles = sortedFiles.filter(_.startsWith("app_"))
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ (apps, workers)
+ }
+
+ private def serializeIntoFile(path: String, value: AnyRef) {
+ val serializer = serialization.findSerializerFor(value)
+ val serialized = serializer.toBinary(value)
+ zk.create(path, serialized, CreateMode.PERSISTENT)
+ }
+
+ def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
+ val fileData = zk.getData("/spark/master_status/" + filename)
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
+ val serializer = serialization.serializerFor(clazz)
+ serializer.fromBinary(fileData).asInstanceOf[T]
+ }
+}
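The engine stores each application and worker as a single znode whose payload is a serialized blob, and rebuilds the objects by class on recovery. A sketch of that serialize/deserialize shape using plain Java serialization and an in-memory map in place of Akka's Serialization extension and ZooKeeper (all names here are illustrative):

    import java.io._
    import scala.collection.mutable

    object PersistenceSketch extends App {
      case class WorkerRecord(id: String, cores: Int) extends Serializable

      val store = mutable.HashMap[String, Array[Byte]]()   // stands in for the ZK working dir

      def serializeIntoFile(path: String, value: AnyRef) {
        val bytes = new ByteArrayOutputStream
        val out = new ObjectOutputStream(bytes)
        out.writeObject(value)
        out.close()
        store(path) = bytes.toByteArray
      }

      def deserializeFromFile[T](path: String): T = {
        new ObjectInputStream(new ByteArrayInputStream(store(path))).readObject().asInstanceOf[T]
      }

      serializeIntoFile("/spark/master_status/worker_w-1", WorkerRecord("w-1", 8))
      val restored = deserializeFromFile[WorkerRecord]("/spark/master_status/worker_w-1")
      println(restored)   // WorkerRecord(w-1,8)
    }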
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index e3dc30eefc..fff9cb60c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -43,7 +43,8 @@ private[spark] class ExecutorRunner(
val workerId: String,
val host: String,
val sparkHome: File,
- val workDir: File)
+ val workDir: File,
+ var state: ExecutorState.Value)
extends Logging {
val fullId = appId + "/" + execId
@@ -83,7 +84,8 @@ private[spark] class ExecutorRunner(
process.destroy()
process.waitFor()
}
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
+ state = ExecutorState.KILLED
+ worker ! ExecutorStateChanged(appId, execId, state, None, None)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -102,7 +104,7 @@ private[spark] class ExecutorRunner(
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
- command.arguments.map(substituteVariables)
+ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/**
@@ -180,9 +182,9 @@ private[spark] class ExecutorRunner(
// long-lived processes only. However, in the future, we might restart the executor a few
// times on the same machine.
val exitCode = process.waitFor()
+ state = ExecutorState.FAILED
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
- Some(exitCode))
+ worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))
} catch {
case interrupted: InterruptedException =>
logInfo("Runner thread for executor " + fullId + " interrupted")
@@ -192,8 +194,9 @@ private[spark] class ExecutorRunner(
if (process != null) {
process.destroy()
}
+ state = ExecutorState.FAILED
val message = e.getClass + ": " + e.getMessage
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
+ worker ! ExecutorStateChanged(appId, execId, state, Some(message), None)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 3904b701b2..991b22d9f8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -23,26 +23,42 @@ import java.io.File
import scala.collection.mutable.HashMap
import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
-import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.actor._
import akka.remote.{RemotingLifecycleEvent, AssociationErrorEvent, DisassociatedEvent}
import org.apache.spark.Logging
-import org.apache.spark.deploy.ExecutorState
+import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
-
-
+import org.apache.spark.deploy.DeployMessages.WorkerStateResponse
+import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
+import org.apache.spark.deploy.DeployMessages.KillExecutor
+import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import scala.Some
+import akka.remote.DisassociatedEvent
+import org.apache.spark.deploy.DeployMessages.LaunchExecutor
+import org.apache.spark.deploy.DeployMessages.RegisterWorker
+import org.apache.spark.deploy.DeployMessages.WorkerSchedulerStateResponse
+import org.apache.spark.deploy.DeployMessages.MasterChanged
+import org.apache.spark.deploy.DeployMessages.Heartbeat
+import org.apache.spark.deploy.DeployMessages.RegisteredWorker
+import akka.actor.Terminated
+
+/**
+ * @param masterUrls Each URL should look like spark://host:port.
+ */
private[spark] class Worker(
host: String,
port: Int,
webUiPort: Int,
cores: Int,
memory: Int,
- masterUrl: String,
+ masterUrls: Array[String],
workDirPath: String = null)
extends Actor with Logging {
@@ -54,8 +70,18 @@ private[spark] class Worker(
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
+ val REGISTRATION_TIMEOUT = 20.seconds
+ val REGISTRATION_RETRIES = 3
+
+ // Index into masterUrls that we're currently trying to register with.
+ var masterIndex = 0
+
+ val masterLock: Object = new Object()
var master: ActorRef = null
- var masterWebUiUrl : String = ""
+ var activeMasterUrl: String = ""
+ var activeMasterWebUiUrl : String = ""
+ @volatile var registered = false
+ @volatile var connected = false
val workerId = generateWorkerId()
var sparkHome: File = null
var workDir: File = null
@@ -95,6 +121,7 @@ private[spark] class Worker(
}
override def preStart() {
+ assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
@@ -103,46 +130,100 @@ private[spark] class Worker(
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
webUi.start()
- connectToMaster()
+ registerWithMaster()
metricsSystem.registerSource(workerSource)
metricsSystem.start()
}
- def connectToMaster() {
- logInfo("Connecting to master " + masterUrl)
- master = context.actorFor(Master.toAkkaUrl(masterUrl))
- master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ def changeMaster(url: String, uiUrl: String) {
+ masterLock.synchronized {
+ activeMasterUrl = url
+ activeMasterWebUiUrl = uiUrl
+ master = context.actorFor(Master.toAkkaUrl(activeMasterUrl))
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ connected = true
+ }
+ }
+
+ def tryRegisterAllMasters() {
+ for (masterUrl <- masterUrls) {
+ logInfo("Connecting to master " + masterUrl + "...")
+ val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
+ publicAddress)
+ }
+ }
+
+ def registerWithMaster() {
+ tryRegisterAllMasters()
+
+ var retries = 0
+ lazy val retryTimer: Cancellable =
+ context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ retries += 1
+ if (registered) {
+ retryTimer.cancel()
+ } else if (retries >= REGISTRATION_RETRIES) {
+ logError("All masters are unresponsive! Giving up.")
+ System.exit(1)
+ } else {
+ tryRegisterAllMasters()
+ }
+ }
+ retryTimer // start timer
}
import context.dispatcher
override def receive = {
- case RegisteredWorker(url) =>
- masterWebUiUrl = url
- logInfo("Successfully registered with master")
- context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
- master ! Heartbeat(workerId)
+ case RegisteredWorker(masterUrl, masterWebUiUrl) =>
+ logInfo("Successfully registered with master " + masterUrl)
+ registered = true
+ changeMaster(masterUrl, masterWebUiUrl)
+ context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
+
+ case SendHeartbeat =>
+ masterLock.synchronized {
+ if (connected) { master ! Heartbeat(workerId) }
}
+ case MasterChanged(masterUrl, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterUrl)
+ context.unwatch(master)
+ changeMaster(masterUrl, masterWebUiUrl)
+
+ val execs = executors.values.
+ map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
+ sender ! WorkerSchedulerStateResponse(workerId, execs.toList)
+
case RegisterWorkerFailed(message) =>
- logError("Worker registration failed: " + message)
- System.exit(1)
-
- case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
- logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
- val manager = new ExecutorRunner(
- appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir)
- executors(appId + "/" + execId) = manager
- manager.start()
- coresUsed += cores_
- memoryUsed += memory_
- master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
+ if (!registered) {
+ logError("Worker registration failed: " + message)
+ System.exit(1)
+ }
+
+ case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
+ if (masterUrl != activeMasterUrl) {
+ logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
+ } else {
+ logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+ val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
+ self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING)
+ executors(appId + "/" + execId) = manager
+ manager.start()
+ coresUsed += cores_
+ memoryUsed += memory_
+ masterLock.synchronized {
+ master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
+ }
+ }
case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
- master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ masterLock.synchronized {
+ master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ }
val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
@@ -155,14 +236,18 @@ private[spark] class Worker(
memoryUsed -= executor.memory
}
- case KillExecutor(appId, execId) =>
- val fullId = appId + "/" + execId
- executors.get(fullId) match {
- case Some(executor) =>
- logInfo("Asked to kill executor " + fullId)
- executor.kill()
- case None =>
- logInfo("Asked to kill unknown executor " + fullId)
+ case KillExecutor(masterUrl, appId, execId) =>
+ if (masterUrl != activeMasterUrl) {
+ logWarning("Invalid Master (" + masterUrl + ") attempted to kill executor " + execId)
+ } else {
+ val fullId = appId + "/" + execId
+ executors.get(fullId) match {
+ case Some(executor) =>
+ logInfo("Asked to kill executor " + fullId)
+ executor.kill()
+ case None =>
+ logInfo("Asked to kill unknown executor " + fullId)
+ }
}
case DisassociatedEvent(_, _, _) =>
@@ -170,17 +255,14 @@ private[spark] class Worker(
case RequestWorkerState => {
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
- finishedExecutors.values.toList, masterUrl, cores, memory,
- coresUsed, memoryUsed, masterWebUiUrl)
+ finishedExecutors.values.toList, activeMasterUrl, cores, memory,
+ coresUsed, memoryUsed, activeMasterWebUiUrl)
}
}
def masterDisconnected() {
- // TODO: It would be nice to try to reconnect to the master, but just shut down for now.
- // (Note that if reconnecting we would also need to assign IDs differently.)
- logError("Connection to master failed! Shutting down.")
- executors.values.foreach(_.kill())
- System.exit(1)
+ logError("Connection to master failed! Waiting for master to reconnect...")
+ connected = false
}
def generateWorkerId(): String = {
@@ -198,17 +280,18 @@ private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
- args.memory, args.master, args.workDir)
+ args.memory, args.masters, args.workDir)
actorSystem.awaitTermination()
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
- masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None)
+ : (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrl, workDir), name = "Worker")
+ masterUrls, workDir), name = "Worker")
(actorSystem, boundPort)
}
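registerWithMaster sends one RegisterWorker to every known master and then re-sends on a fixed timer until either a RegisteredWorker response flips the registered flag or the retry budget runs out. A self-contained sketch of that retry loop using a ScheduledExecutorService instead of the Akka scheduler (the constants and the fake tryRegisterAllMasters are illustrative):

    import java.util.concurrent.{Executors, TimeUnit}
    import java.util.concurrent.atomic.AtomicBoolean

    object RegistrationRetrySketch extends App {
      val RegistrationTimeoutSecs = 2
      val RegistrationRetries = 3

      val registered = new AtomicBoolean(false)
      def tryRegisterAllMasters() { println("sending RegisterWorker to every known master") }

      tryRegisterAllMasters()

      val scheduler = Executors.newSingleThreadScheduledExecutor()
      var retries = 0
      val task: Runnable = new Runnable {
        def run() {
          retries += 1
          if (registered.get) {
            scheduler.shutdown()                 // success: stop retrying
          } else if (retries >= RegistrationRetries) {
            println("All masters are unresponsive! Giving up.")
            scheduler.shutdown()
          } else {
            tryRegisterAllMasters()              // try the whole list again
          }
        }
      }
      scheduler.scheduleAtFixedRate(task, RegistrationTimeoutSecs, RegistrationTimeoutSecs,
        TimeUnit.SECONDS)
    }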
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 0ae89a864f..3ed528e6b3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -29,7 +29,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
var webUiPort = 8081
var cores = inferDefaultCores()
var memory = inferDefaultMemory()
- var master: String = null
+ var masters: Array[String] = null
var workDir: String = null
// Check for settings in environment variables
@@ -86,14 +86,14 @@ private[spark] class WorkerArguments(args: Array[String]) {
printUsageAndExit(0)
case value :: tail =>
- if (master != null) { // Two positional arguments were given
+ if (masters != null) { // Two positional arguments were given
printUsageAndExit(1)
}
- master = value
+ masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
parse(tail)
case Nil =>
- if (master == null) { // No positional argument was given
+ if (masters == null) { // No positional argument was given
printUsageAndExit(1)
}
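The new positional argument accepts several masters in one spark:// URL, separated by commas; a quick sketch of what the stripPrefix/split/map line above produces:

    object MasterListParsingDemo extends App {
      val value = "spark://host1:7077,host2:7077,host3:7077"
      val masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
      masters.foreach(println)
      // spark://host1:7077
      // spark://host2:7077
      // spark://host3:7077
    }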
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 07bc479c83..a38e32b339 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -108,7 +108,7 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node>
- val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p>
+ val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index f705a5631a..73fa7d6b6a 100644
--- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -24,23 +24,15 @@ import akka.remote._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
import akka.remote.DisassociatedEvent
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
import akka.remote.AssociationErrorEvent
import akka.remote.DisassociatedEvent
import akka.actor.Terminated
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
-private[spark] class StandaloneExecutorBackend(
+private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
hostPort: String,
@@ -75,15 +67,28 @@ private[spark] class StandaloneExecutorBackend(
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
if (executor == null) {
- logError("Received launchTask but executor was null")
+ logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}
+ case KillTask(taskId, _) =>
+ if (executor == null) {
+ logError("Received KillTask command but executor was null")
+ System.exit(1)
+ } else {
+ executor.killTask(taskId)
+ }
+
case DisassociatedEvent(_, _, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
+
+ case StopExecutor =>
+ logInfo("Driver commanded a shutdown")
+ context.stop(self)
+ context.system.shutdown()
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
@@ -91,7 +96,7 @@ private[spark] class StandaloneExecutorBackend(
}
}
-private[spark] object StandaloneExecutorBackend {
+private[spark] object CoarseGrainedExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Debug code
Utils.checkHost(hostname)
@@ -102,16 +107,19 @@ private[spark] object StandaloneExecutorBackend {
// set it
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
+
actorSystem.actorOf(
- Props(classOf[StandaloneExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
+ Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
if (args.length < 4) {
- //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
- System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
+ // The reason we allow the last appid argument is to make it easy to kill rogue executors
+ System.err.println(
+ "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
+ "[<appid>]")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3800063234..de4540493a 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,9 +25,10 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
-import org.apache.spark.scheduler._
import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
/**
@@ -36,7 +37,8 @@ import org.apache.spark.util.Utils
private[spark] class Executor(
executorId: String,
slaveHostname: String,
- properties: Seq[(String, String)])
+ properties: Seq[(String, String)],
+ isLocal: Boolean = false)
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -73,46 +75,75 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
- // Make any thread terminations due to uncaught exceptions kill the entire
- // executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(
- new Thread.UncaughtExceptionHandler {
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ if (!isLocal) {
+ // Setup an uncaught exception handler for non-local mode.
+ // Make any thread terminations due to uncaught exceptions kill the entire
+ // executor process to avoid surprising stalls.
+ Thread.setDefaultUncaughtExceptionHandler(
+ new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(ExecutorExitCode.OOM)
+ } else {
+ System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ }
}
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
- }
- )
+ )
+ }
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
- SparkEnv.set(env)
- env.metricsSystem.registerSource(executorSource)
+ private val env = {
+ if (!isLocal) {
+ val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+ isDriver = false, isLocal = false)
+ SparkEnv.set(_env)
+ _env.metricsSystem.registerSource(executorSource)
+ _env
+ } else {
+ SparkEnv.get
+ }
+ }
private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
// Start worker thread pool
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+ val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
+
+ // Maintains the list of running tasks.
+ private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+
+ val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
- threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+ val tr = new TaskRunner(context, taskId, serializedTask)
+ runningTasks.put(taskId, tr)
+ threadPool.execute(tr)
+ }
+
+ def killTask(taskId: Long) {
+ val tr = runningTasks.get(taskId)
+ if (tr != null) {
+ tr.kill()
+ // We remove the task also in the finally block in TaskRunner.run.
+ // The reason we need to remove it here is because killTask might be called before the task
+ // is even launched, so the task would never reach that finally block. ConcurrentHashMap's remove is
+ // idempotent.
+ runningTasks.remove(taskId)
+ }
}
/** Get the Yarn approved local directories. */
@@ -124,56 +155,87 @@ private[spark] class Executor(
.getOrElse(Option(System.getenv("LOCAL_DIRS"))
.getOrElse(""))
- if (localDirs.isEmpty()) {
+ if (localDirs.isEmpty) {
throw new Exception("Yarn Local dirs can't be empty")
}
- return localDirs
+ localDirs
}
- class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
+ class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
- override def run() {
+ @volatile private var killed = false
+ @volatile private var task: Task[Any] = _
+
+ def kill() {
+ logInfo("Executor is trying to kill task " + taskId)
+ killed = true
+ if (task != null) {
+ task.kill()
+ }
+ }
+
+ override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
- context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
+ execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var attemptedTask: Option[Task[Any]] = None
var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
+ def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+ val startGCTime = gcTime
try {
SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
- val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+
+ // If this task has been killed before we deserialized it, let's quit now. Otherwise,
+ // continue executing the task.
+ if (killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
attemptedTask = Some(task)
- logInfo("Its epoch is " + task.epoch)
+ logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
+
+ // Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
+
+ // If the task has been killed, let's fail it.
+ if (task.killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
for (m <- task.metrics) {
- m.hostname = Utils.localHostName
+ m.hostname = Utils.localHostName()
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
- //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
- // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
- // just change the relevants bytes in the byte buffer
+ // TODO I'd also like to track the time it takes to serialize the task results, but that is
+ // a huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
+ // custom serialized format, we could just change the relevant bytes in the byte buffer
val accumUpdates = Accumulators.values
+
val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
logInfo("Storing result for " + taskId + " in local BlockManager")
- val blockId = "taskresult_" + taskId
+ val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
ser.serialize(new IndirectTaskResult[Any](blockId))
@@ -182,12 +244,13 @@ private[spark] class Executor(
serializedDirectResult
}
}
- context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
+
+ execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
case t: Throwable => {
@@ -195,10 +258,10 @@ private[spark] class Executor(
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.executorRunTime = serviceTime
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
// have left some weird state around depending on when the exception was thrown, but on
@@ -206,6 +269,8 @@ private[spark] class Executor(
logError("Exception in task ID " + taskId, t)
//System.exit(1)
}
+ } finally {
+ runningTasks.remove(taskId)
}
}
}
@@ -215,7 +280,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
- var loader = this.getClass.getClassLoader
+ val loader = this.getClass.getClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@@ -237,7 +302,7 @@ private[spark] class Executor(
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- return constructor.newInstance(classUri, parent)
+ constructor.newInstance(classUri, parent)
} catch {
case _: ClassNotFoundException =>
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@@ -245,7 +310,7 @@ private[spark] class Executor(
null
}
} else {
- return parent
+ parent
}
}
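killTask has to work whether or not the TaskRunner has reached its run() body, which is why the runner is registered in runningTasks before execution, flagged as killed, and removed both in killTask and in run()'s finally block. A self-contained sketch of that bookkeeping with a plain thread pool (the Runner and its "work" are stand-ins, not the Spark TaskRunner):

    import java.util.concurrent.{ConcurrentHashMap, Executors}

    object KillBookkeepingSketch extends App {
      class Runner(val taskId: Long) extends Runnable {
        @volatile private var killed = false
        def kill() { killed = true }
        def run() {
          try {
            if (killed) { println("task " + taskId + " killed before it started"); return }
            println("task " + taskId + " running")
          } finally {
            runningTasks.remove(taskId)     // always drop the entry, success or failure
          }
        }
      }

      val runningTasks = new ConcurrentHashMap[Long, Runner]
      val pool = Executors.newFixedThreadPool(2)

      def launchTask(taskId: Long) {
        val r = new Runner(taskId)
        runningTasks.put(taskId, r)         // register before execution so kill can find it
        pool.execute(r)
      }

      def killTask(taskId: Long) {
        val r = runningTasks.get(taskId)
        if (r != null) {
          r.kill()
          runningTasks.remove(taskId)       // remove() is idempotent with the finally block
        }
      }

      launchTask(1); killTask(1); launchTask(2)
      pool.shutdown()
    }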
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index da62091980..b56d8c9912 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -18,14 +18,18 @@
package org.apache.spark.executor
import java.nio.ByteBuffer
-import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
-import org.apache.spark.TaskState.TaskState
+
import com.google.protobuf.ByteString
-import org.apache.spark.{Logging}
+
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
+
+import org.apache.spark.Logging
import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
import org.apache.spark.util.Utils
+
private[spark] class MesosExecutorBackend
extends MesosExecutor
with ExecutorBackend
@@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
}
override def killTask(d: ExecutorDriver, t: TaskID) {
- logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+ if (executor == null) {
+ logError("Received KillTask but executor was null")
+ } else {
+ executor.killTask(t.getValue.toLong)
+ }
}
override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index f311141148..0b4892f98f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable {
* Number of bytes written for a shuffle
*/
var shuffleBytesWritten: Long = _
+
+ /**
+ * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ */
+ var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index c24fd48c04..703bc6a9ca 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
- implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+ implicit val futureExecContext = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
index 3c29700920..1b9fa1e53a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -20,17 +20,18 @@ package org.apache.spark.network.netty
import io.netty.buffer._
import org.apache.spark.Logging
+import org.apache.spark.storage.{TestBlockId, BlockId}
private[spark] class FileHeader (
val fileLen: Int,
- val blockId: String) extends Logging {
+ val blockId: BlockId) extends Logging {
lazy val buffer = {
val buf = Unpooled.buffer()
buf.capacity(FileHeader.HEADER_SIZE)
buf.writeInt(fileLen)
- buf.writeInt(blockId.length)
- blockId.foreach((x: Char) => buf.writeByte(x))
+ buf.writeInt(blockId.name.length)
+ blockId.name.foreach((x: Char) => buf.writeByte(x))
//padding the rest of header
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
@@ -57,18 +58,15 @@ private[spark] object FileHeader {
for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char]
}
- val blockId = idBuilder.toString()
+ val blockId = BlockId(idBuilder.toString())
new FileHeader(length, blockId)
}
-
- def main (args:Array[String]){
-
- val header = new FileHeader(25,"block_0");
- val buf = header.buffer;
- val newheader = FileHeader.create(buf);
- System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
-
+ def main (args:Array[String]) {
+ val header = new FileHeader(25, TestBlockId("my_block"))
+ val buf = header.buffer
+ val newHeader = FileHeader.create(buf)
+ System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
}
}
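FileHeader lays the header out as an int file length, an int block-id length, the block-id bytes, and zero padding up to the fixed header size. A sketch of the same layout using java.nio.ByteBuffer rather than Netty's Unpooled buffers (the 40-byte header size is an assumption for the example):

    import java.nio.ByteBuffer

    object HeaderLayoutSketch extends App {
      val HeaderSize = 40                       // illustrative fixed header size

      def encode(fileLen: Int, blockId: String): ByteBuffer = {
        val buf = ByteBuffer.allocate(HeaderSize)
        buf.putInt(fileLen)
        buf.putInt(blockId.length)
        blockId.foreach(c => buf.put(c.toByte))
        while (buf.position() < HeaderSize) buf.put(0.toByte)   // pad the rest of the header
        buf.flip()
        buf
      }

      def decode(buf: ByteBuffer): (Int, String) = {
        val fileLen = buf.getInt()
        val idLen = buf.getInt()
        val id = new String(Array.fill(idLen)(buf.get().toChar))
        (fileLen, id)
      }

      val (len, id) = decode(encode(25, "test_my_block"))
      println("id=" + id + ",size=" + len)      // id=test_my_block,size=25
    }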
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index 9493ccffd9..481ff8c3e0 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -27,12 +27,13 @@ import org.apache.spark.Logging
import org.apache.spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
+import org.apache.spark.storage.BlockId
private[spark] class ShuffleCopier extends Logging {
- def getBlock(host: String, port: Int, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(host: String, port: Int, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
@@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
try {
fc.init()
fc.connect(host, port)
- fc.sendRequest(blockId)
+ fc.sendRequest(blockId.name)
fc.waitForClose()
fc.close()
} catch {
@@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
}
}
- def getBlock(cmId: ConnectionManagerId, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
- blocks: Seq[(String, Long)],
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ blocks: Seq[(BlockId, Long)],
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
for ((blockId, size) <- blocks) {
getBlock(cmId, blockId, resultCollectCallback)
@@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging {
private[spark] object ShuffleCopier extends Logging {
- private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
@@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging {
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
- override def handleError(blockId: String) {
+ override def handleError(blockId: BlockId) {
if (!isComplete) {
resultCollectCallBack(blockId, -1, null)
}
}
}
- def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
if (size != -1) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
@@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
}
val host = args(0)
val port = args(1).toInt
- val file = args(2)
+ val blockId = BlockId(args(2))
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
@@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging {
Executors.callable(new Runnable() {
def run() {
val copier = new ShuffleCopier()
- copier.getBlock(host, port, file, echoResultCollectCallBack)
+ copier.getBlock(host, port, blockId, echoResultCollectCallBack)
}
})
}).asJava
copiers.invokeAll(tasks)
- copiers.shutdown
+ copiers.shutdown()
System.exit(0)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 8afcbe190a..546d921067 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,6 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
+import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -53,8 +54,8 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!blockId.startsWith("shuffle_")) {
+ override def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
// Figure out which local directory it hashes to, and which subdirectory in that
@@ -62,8 +63,8 @@ private[spark] object ShuffleSender {
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
- val file = new File(subDir, blockId)
- return file.getAbsolutePath
+ val file = new File(subDir, blockId.name)
+ return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index f132e2b735..70a5a8caff 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+package org.apache
+
/**
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
new file mode 100644
index 0000000000..44c5078621
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import scala.reflect.ClassTag
+
+/**
+ * A set of asynchronous RDD actions available through an implicit conversion.
+ * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
+ */
+class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
+
+ /**
+ * Returns a future for counting the number of elements in the RDD.
+ */
+ def countAsync(): FutureAction[Long] = {
+ val totalCount = new AtomicLong
+ self.context.submitJob(
+ self,
+ (iter: Iterator[T]) => {
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next()
+ }
+ result
+ },
+ Range(0, self.partitions.size),
+ (index: Int, data: Long) => totalCount.addAndGet(data),
+ totalCount.get())
+ }
+
+ /**
+ * Returns a future for retrieving all elements of this RDD.
+ */
+ def collectAsync(): FutureAction[Seq[T]] = {
+ val results = new Array[Array[T]](self.partitions.size)
+ self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
+ (index, data) => results(index) = data, results.flatten.toSeq)
+ }
+
+ /**
+ * Returns a future for retrieving the first num elements of the RDD.
+ */
+ def takeAsync(num: Int): FutureAction[Seq[T]] = {
+ val f = new ComplexFutureAction[Seq[T]]
+
+ f.run {
+ val results = new ArrayBuffer[T](num)
+ val totalParts = self.partitions.length
+ var partsScanned = 0
+ while (results.size < num && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (results.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = num - results.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+
+ val buf = new Array[Array[T]](p.size)
+ f.runJob(self,
+ (it: Iterator[T]) => it.take(left).toArray,
+ p,
+ (index: Int, data: Array[T]) => buf(index) = data,
+ Unit)
+
+ buf.foreach(results ++= _.take(num - results.size))
+ partsScanned += numPartsToTry
+ }
+ results.toSeq
+ }
+
+ f
+ }
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
+ def foreachAsync(f: T => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+
+ /**
+ * Applies a function f to each partition of this RDD.
+ */
+ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+}
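A minimal usage sketch of these async actions, assuming a Spark build containing this change is on the classpath and that the implicit conversion mentioned in the class comment is brought in via import org.apache.spark.SparkContext._ (the local master string and the data are illustrative):

    import scala.concurrent.Await
    import scala.concurrent.duration.Duration

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._

    object AsyncActionsUsage extends App {
      val sc = new SparkContext("local[2]", "async-demo")
      val rdd = sc.parallelize(1 to 100000, 4)

      // Both calls return immediately; the jobs run in the background.
      val countFuture = rdd.countAsync()
      val takeFuture = rdd.takeAsync(10)

      // A FutureAction can be awaited like a scala.concurrent.Future.
      println("count = " + Await.result(countFuture, Duration.Inf))
      println("first = " + Await.result(takeFuture, Duration.Inf))

      sc.stop()
    }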
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index fe2946bcbe..63b9fe1478 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -18,15 +18,15 @@
package org.apache.spark.rdd
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockId, BlockManager}
import scala.reflect.ClassTag
-private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
+private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
}
private[spark]
-class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[String])
+class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 3f4d4ad46a..99ea6e8ee8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, BytesWritable}
@@ -84,9 +85,9 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val env = SparkEnv.get
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
+ val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration())
- val finalOutputName = splitIdToFile(ctx.splitId)
+ val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
@@ -123,7 +124,7 @@ private[spark] object CheckpointRDD extends Logging {
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val env = SparkEnv.get
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
@@ -146,7 +147,7 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 0187256a8e..911a002884 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -18,13 +18,12 @@
package org.apache.spark.rdd
import java.io.{ObjectOutputStream, IOException}
-import java.util.{HashMap => JHashMap}
-import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark.util.AppendOnlyMap
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -105,17 +104,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+ val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
- def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- val seq = map.get(k)
- if (seq != null) {
- seq
- } else {
- val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
- map.put(k, seq)
- seq
- }
+ val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
+ }
+
+ val getSeq = (k: K) => {
+ map.changeValue(k, update)
}
val ser = SparkEnv.get.serializerManager.get(serializerClass)
@@ -129,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+ fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
kv => getSeq(kv._1)(depNum) += kv._2
}
}
}
- JavaConversions.mapAsScalaMap(map).iterator
+ new InterruptibleIterator(context, map.iterator)
}
override def clearDependencies() {
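The rewrite replaces the JHashMap get/put dance with AppendOnlyMap.changeValue, whose update function is told whether a previous value existed and either keeps the existing per-RDD buffers or allocates a fresh array of them. A sketch of the same get-or-create grouping with a plain mutable map (stand-in types, not the AppendOnlyMap API):

    import scala.collection.mutable
    import scala.collection.mutable.ArrayBuffer

    object CoGroupGroupingSketch extends App {
      val numRdds = 2
      val map = new mutable.HashMap[String, Seq[ArrayBuffer[Any]]]

      // Equivalent of changeValue: reuse the existing buffers, or allocate one per co-grouped RDD.
      def getSeq(k: String): Seq[ArrayBuffer[Any]] =
        map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))

      // (key, value) pairs arriving from dependency 0 and dependency 1.
      getSeq("a")(0) += 1
      getSeq("a")(1) += "x"
      getSeq("b")(0) += 2

      map.foreach { case (k, bufs) => println(k + " -> " + bufs.map(_.toList).toList) }
    }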
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index d3b3fffd40..32901a508f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -27,54 +27,19 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
- TaskContext}
+import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
-/**
- * An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file
- * system, or S3).
- * This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same
- * across multiple reads; the 'path' is the only variable that is different across new JobConfs
- * created from the Configuration.
- */
-class HadoopFileRDD[K, V](
- sc: SparkContext,
- path: String,
- broadcastedConf: Broadcast[SerializableWritable[Configuration]],
- inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K],
- valueClass: Class[V],
- minSplits: Int)
- extends HadoopRDD[K, V](sc, broadcastedConf, inputFormatClass, keyClass, valueClass, minSplits) {
-
- override def getJobConf(): JobConf = {
- if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
- // getJobConf() has been called previously, so there is already a local cache of the JobConf
- // needed by this RDD.
- return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
- } else {
- // Create a new JobConf, set the input file/directory paths to read from, and cache the
- // JobConf (i.e., in a shared hash map in the slave's JVM process that's accessible through
- // HadoopRDD.putCachedMetadata()), so that we only create one copy across multiple
- // getJobConf() calls for this RDD in the local process.
- // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
- val newJobConf = new JobConf(broadcastedConf.value.value)
- FileInputFormat.setInputPaths(newJobConf, path)
- HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
- return newJobConf
- }
- }
-}
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
extends Partition {
-
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -83,11 +48,24 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
}
/**
- * A base class that provides core functionality for reading data partitions stored in Hadoop.
+ * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
+ * sources in HBase, or S3).
+ *
+ * @param sc The SparkContext to associate the RDD with.
+ * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
+ * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job.
+ * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
+ * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
+ * creates.
+ * @param inputFormatClass Storage format of the data to be read.
+ * @param keyClass Class of the key associated with the inputFormatClass.
+ * @param valueClass Class of the value associated with the inputFormatClass.
+ * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
*/
class HadoopRDD[K, V](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
+ initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
@@ -105,6 +83,7 @@ class HadoopRDD[K, V](
sc,
sc.broadcast(new SerializableWritable(conf))
.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
+ None /* initLocalJobConfFuncOpt */,
inputFormatClass,
keyClass,
valueClass,
@@ -130,6 +109,7 @@ class HadoopRDD[K, V](
// local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
// The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
val newJobConf = new JobConf(broadcastedConf.value.value)
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
return newJobConf
}
@@ -164,38 +144,41 @@ class HadoopRDD[K, V](
array
}
- override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopPartition]
- logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
-
- val jobConf = getJobConf()
- val inputFormat = getInputFormat(jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => closeIfNeeded() }
-
- val key: K = reader.createKey()
- val value: V = reader.createValue()
-
- override def getNext() = {
- try {
- finished = !reader.next(key, value)
- } catch {
- case eof: EOFException =>
- finished = true
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new NextIterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopPartition]
+ logInfo("Input split: " + split.inputSplit)
+ var reader: RecordReader[K, V] = null
+
+ val jobConf = getJobConf()
+ val inputFormat = getInputFormat(jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
+ override def getNext() = {
+ try {
+ finished = !reader.next(key, value)
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ (key, value)
}
- (key, value)
- }
- override def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ override def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator[(K, V)](context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
@@ -216,10 +199,10 @@ private[spark] object HadoopRDD {
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
*/
- def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key)
+ def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key)
- def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key)
+ def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key)
def putCachedMetadata(key: String, value: Any) =
- SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value)
+ SparkEnv.get.hadoopJobMetadata.put(key, value)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
index 3cf22851dd..67636751bb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
@@ -22,14 +22,14 @@ import scala.reflect.ClassTag
/**
- * A variant of the MapPartitionsRDD that passes the partition index into the
- * closure. This can be used to generate or collect partition specific
- * information such as the number of tuples in a partition.
+ * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
+ * TaskContext, the closure can either get access to the interruptible flag or get the index
+ * of the partition in the RDD.
*/
private[spark]
-class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
+class MapPartitionsWithContextRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
- f: (Int, Iterator[T]) => Iterator[U],
+ f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
@@ -38,5 +38,5 @@ class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def compute(split: Partition, context: TaskContext) =
- f(split.index, firstParent[T].iterator(split, context))
+ f(context, firstParent[T].iterator(split, context))
}
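Since the closure now receives the whole TaskContext, it can read the partition id and other task-level state from one place; the index-only variant is layered on top of it via context.partitionId (see the RDD.scala hunk further below). A small usage sketch of the new mapPartitionsWithContext operation, assuming an existing SparkContext sc running a build that includes these changes:

    // Tag every element with the id of the partition that produced it.
    val rdd = sc.parallelize(1 to 100, 4)
    val tagged = rdd.mapPartitionsWithContext(
      (context, iter) => iter.map(x => (context.partitionId, x)),
      preservesPartitioning = false)
    tagged.collect()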
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7b3a89f7e0..2662d48c84 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
private[spark]
@@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
result
}
- override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopPartition]
- logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- if (format.isInstanceOf[Configurable]) {
- format.asInstanceOf[Configurable].setConf(conf)
- }
- val reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => close())
-
- var havePair = false
- var finished = false
-
- override def hasNext: Boolean = {
- if (!finished && !havePair) {
- finished = !reader.nextKeyValue
- havePair = !finished
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopPartition]
+ logInfo("Input split: " + split.serializableHadoopSplit)
+ val conf = confBroadcast.value.value
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ if (format.isInstanceOf[Configurable]) {
+ format.asInstanceOf[Configurable].setConf(conf)
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => close())
+
+ var havePair = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !havePair) {
+ finished = !reader.nextKeyValue
+ havePair = !finished
+ }
+ !finished
}
- !finished
- }
- override def next: (K, V) = {
- if (!hasNext) {
- throw new java.util.NoSuchElementException("End of stream")
+ override def next(): (K, V) = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ havePair = false
+ (reader.getCurrentKey, reader.getCurrentValue)
}
- havePair = false
- return (reader.getCurrentKey, reader.getCurrentValue)
- }
- private def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ private def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator(context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c8e623081a..0c2a051a42 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -85,18 +85,24 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
- self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ self.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
- partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+ partitioned.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
+ }, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
- values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ values.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
}
}
@@ -565,7 +571,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -664,7 +670,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- writer.setup(context.stageId, context.splitId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
var count = 0
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 78fe0cdcdb..09d0a8189d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -96,8 +96,9 @@ private[spark] class ParallelCollectionRDD[T: ClassTag](
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
- override def compute(s: Partition, context: TaskContext) =
- s.asInstanceOf[ParallelCollectionPartition[T]].iterator
+ override def compute(s: Partition, context: TaskContext) = {
+ new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
+ }
override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)
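The common thread in these rdd hunks is that compute() no longer hands back the raw iterator; it wraps it in an InterruptibleIterator bound to the TaskContext, so a task cancelled through the new scheduler paths stops consuming input early. A minimal sketch of the wrapping pattern as used above (the class is part of Spark's core internals, so treat this as illustrative rather than public API):

    import org.apache.spark.{InterruptibleIterator, TaskContext}

    // Given the task's context and the iterator that actually produces the data,
    // return the interruption-aware view that compute() now exposes to callers.
    def interruptibleCompute[T](context: TaskContext, delegate: Iterator[T]): Iterator[T] =
      new InterruptibleIterator[T](context, delegate)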
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 731ef90c90..3c237ca20a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -269,6 +269,19 @@ abstract class RDD[T: ClassTag](
def distinct(): RDD[T] = distinct(partitions.size)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): RDD[T] = {
+ coalesce(numPartitions, true)
+ }
+
+ /**
* Return a new RDD that is reduced into `numPartitions` partitions.
*
* This results in a narrow dependency, e.g. if you go from 1000 partitions
@@ -421,26 +434,39 @@ abstract class RDD[T: ClassTag](
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
+ preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[U: ClassTag](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
+ new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
+ * mapPartitions that also passes the TaskContext into the closure.
+ */
+ def mapPartitionsWithContext[U: ClassTag](
+ f: (TaskContext, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = {
+ new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
@@ -448,22 +474,23 @@ abstract class RDD[T: ClassTag](
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassTag](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ mapPartitionsWithIndex(f, preservesPartitioning)
+ }
/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => U): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.map(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def mapWith[A: ClassTag, U: ClassTag]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => U): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -471,13 +498,14 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => Seq[U]): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def flatMapWith[A: ClassTag, U: ClassTag]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => Seq[U]): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.flatMap(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -485,13 +513,12 @@ abstract class RDD[T: ClassTag](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassTag](constructA: Int => A)
- (f:(T, A) => Unit) {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.map(t => {f(t, a); t})
- }
- (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+ def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => {f(t, a); t})
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}
/**
@@ -499,13 +526,12 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassTag](constructA: Int => A)
- (p:(T, A) => Boolean): RDD[T] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.filter(t => p(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+ def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.filter(t => p(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}
/**
@@ -544,16 +570,14 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
@@ -678,6 +702,8 @@ abstract class RDD[T: ClassTag](
*/
def count(): Long = {
sc.runJob(this, (iter: Iterator[T]) => {
+ // Use a while loop to count the number of elements rather than iter.size because
+      // iter.size uses a for loop, which is slightly slower in the current version of Scala.
var result = 0L
while (iter.hasNext) {
result += 1L
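The new repartition above is simply coalesce(numPartitions, true), so it can both grow and shrink the partition count at the cost of a shuffle, while plain coalesce stays a narrow dependency. A short usage sketch, assuming an existing SparkContext sc:

    val rdd = sc.parallelize(1 to 1000, 100)

    // Growing the partition count requires a shuffle, which is exactly what repartition does.
    val wider = rdd.repartition(200)

    // Shrinking without a shuffle: coalesce merges existing partitions through a narrow dependency.
    val narrower = rdd.coalesce(10, false)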
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index b7205865cf..1d109a2496 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -57,7 +57,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
SparkEnv.get.serializerManager.get(serializerClass))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 85c512f3de..aab30b1bb4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -111,7 +111,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
case ShuffleCoGroupSplitDep(shuffleId) => {
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context.taskMetrics, serializer)
+ context, serializer)
iter.foreach(op)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 03fe0e00f9..ab7b3a2e24 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -29,8 +29,8 @@ import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
-import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -42,34 +42,40 @@ import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
* locations to run each task on, based on the current cache status, and passes these to the
* low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
* lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
- * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task
+ * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
* THREADING: This class runs all its logic in a single thread executing the run() method, to which
- * events are submitted using a synchonized queue (eventQueue). The public API methods, such as
+ * events are submitted using a synchronized queue (eventQueue). The public API methods, such as
* runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
* should be private.
*/
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
- mapOutputTracker: MapOutputTracker,
+ mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
- extends TaskSchedulerListener with Logging {
+ extends Logging {
def this(taskSched: TaskScheduler) {
- this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ SparkEnv.get.blockManager.master, SparkEnv.get)
}
- taskSched.setListener(this)
+ taskSched.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(BeginEvent(task, taskInfo))
}
+ // Called to report that a task has completed and results are being fetched remotely.
+ def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
+ eventQueue.put(GettingResultEvent(task, taskInfo))
+ }
+
// Called by TaskScheduler to report task completions or failures.
- override def taskEnded(
+ def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -80,17 +86,18 @@ class DAGScheduler(
}
// Called by TaskScheduler when an executor fails.
- override def executorLost(execId: String) {
+ def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
- override def executorGained(execId: String, host: String) {
+ def executorGained(execId: String, host: String) {
eventQueue.put(ExecutorGained(execId, host))
}
- // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
- override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
+ // cancellation of the job itself.
+ def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -105,13 +112,15 @@ class DAGScheduler(
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
- val nextJobId = new AtomicInteger(0)
+ private[scheduler] val nextJobId = new AtomicInteger(0)
+
+ def numTotalJobs: Int = nextJobId.get()
- val nextStageId = new AtomicInteger(0)
+ private val nextStageId = new AtomicInteger(0)
- val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+ private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
- val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@@ -128,6 +137,7 @@ class DAGScheduler(
// stray messages to detect.
val failedEpoch = new HashMap[String, Long]
+  // Map from job id to the active job
val idToActiveJob = new HashMap[Int, ActiveJob]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
@@ -139,7 +149,7 @@ class DAGScheduler(
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
- val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
// Start a thread to run the DAGScheduler event loop
def start() {
@@ -157,7 +167,7 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
if (!cacheLocs.contains(rdd.id)) {
- val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
cacheLocs(rdd.id) = blockIds.map { id =>
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
@@ -179,7 +189,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+ val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@@ -192,6 +202,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
+ numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@@ -204,9 +215,10 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+ val stage =
+ new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
- stageToInfos(stage) = StageInfo(stage)
+ stageToInfos(stage) = new StageInfo(stage)
stage
}
@@ -262,32 +274,41 @@ class DAGScheduler(
}
/**
- * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
- * JobWaiter whose getResult() method will return the result of the job when it is complete.
- *
- * The job is assumed to have at least one partition; zero partition jobs should be handled
- * without a JobSubmitted event.
+ * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
+ * can be used to block until the job finishes executing or can be used to cancel the job.
*/
- private[scheduler] def prepareJob[T, U: ClassTag](
- finalRdd: RDD[T],
+ def submitJob[T, U](
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
- properties: Properties = null)
- : (JobSubmitted, JobWaiter[U]) =
+ properties: Properties = null): JobWaiter[U] =
{
+ // Check to make sure we are not launching a task on a partition that does not exist.
+ val maxPartitions = rdd.partitions.length
+ partitions.find(p => p >= maxPartitions).foreach { p =>
+ throw new IllegalArgumentException(
+ "Attempting to access a non-existent partition: " + p + ". " +
+ "Total number of partitions: " + maxPartitions)
+ }
+
+ val jobId = nextJobId.getAndIncrement()
+ if (partitions.size == 0) {
+ return new JobWaiter[U](this, jobId, 0, resultHandler)
+ }
+
assert(partitions.size > 0)
- val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
- properties)
- (toSubmit, waiter)
+ val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
+ waiter, properties))
+ waiter
}
def runJob[T, U: ClassTag](
- finalRdd: RDD[T],
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
@@ -295,21 +316,7 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
- if (partitions.size == 0) {
- return
- }
-
- // Check to make sure we are not launching a task on a partition that does not exist.
- val maxPartitions = finalRdd.partitions.length
- partitions.find(p => p >= maxPartitions).foreach { p =>
- throw new IllegalArgumentException(
- "Attempting to access a non-existent partition: " + p + ". " +
- "Total number of partitions: " + maxPartitions)
- }
-
- val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
- finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
- eventQueue.put(toSubmit)
+ val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception, _) =>
@@ -330,19 +337,40 @@ class DAGScheduler(
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
+ val jobId = nextJobId.getAndIncrement()
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
+ listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
/**
+ * Cancel a job that is running or waiting in the queue.
+ */
+ def cancelJob(jobId: Int) {
+ logInfo("Asked to cancel job " + jobId)
+ eventQueue.put(JobCancelled(jobId))
+ }
+
+ def cancelJobGroup(groupId: String) {
+ logInfo("Asked to cancel job group " + groupId)
+ eventQueue.put(JobGroupCancelled(groupId))
+ }
+
+ /**
+ * Cancel all jobs that are running or waiting in the queue.
+ */
+ def cancelAllJobs() {
+ eventQueue.put(AllJobsCancelled)
+ }
+
+ /**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
- val jobId = nextJobId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+ case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
+ val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -361,18 +389,43 @@ class DAGScheduler(
submitStage(finalStage)
}
+ case JobCancelled(jobId) =>
+ // Cancel a job: find all the running stages that are linked to this job, and cancel them.
+ running.filter(_.jobId == jobId).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
+ case JobGroupCancelled(groupId) =>
+ // Cancel all jobs belonging to this job group.
+        // First find all active jobs with this group id, and then kill the stages for them.
+ val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ .map(_.jobId)
+ if (!jobIds.isEmpty) {
+ running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+ }
+
+ case AllJobsCancelled =>
+ // Cancel all running jobs.
+ running.foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
case ExecutorLost(execId) =>
handleExecutorLost(execId)
- case begin: BeginEvent =>
- listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case BeginEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskStart(task, taskInfo))
+
+ case GettingResultEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
- case completion: CompletionEvent =>
- listenerBus.post(SparkListenerTaskEnd(
- completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
+ case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@@ -542,7 +595,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
- listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
+ listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@@ -563,9 +616,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
- if (!stage.submissionTime.isDefined) {
- stage.submissionTime = Some(System.currentTimeMillis())
- }
+ stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -579,15 +630,20 @@ class DAGScheduler(
*/
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
+
+ if (!stageIdToStage.contains(task.stageId)) {
+ // Skip all the actions if the stage has been cancelled.
+ return
+ }
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
- val serviceTime = stage.submissionTime match {
+ val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
- case _ => "Unkown"
+ case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.completionTime = Some(System.currentTimeMillis)
+ stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@@ -627,7 +683,7 @@ class DAGScheduler(
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
- stage.addOutputLoc(smt.partition, status)
+ stage.addOutputLoc(smt.partitionId, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
markStageAsFinished(stage)
@@ -753,14 +809,14 @@ class DAGScheduler(
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
- * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
+ * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
- failedStage.completionTime = Some(System.currentTimeMillis())
+ stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- val error = new SparkException("Job failed: " + reason)
+ val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
idToActiveJob -= resultStage.jobId
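Taken together, the DAGScheduler changes replace the prepareJob/eventQueue dance with a non-blocking submitJob that returns a JobWaiter, plus explicit cancellation entry points (cancelJob, cancelJobGroup, cancelAllJobs) that post events to the same queue. The fragment below illustrates that flow; dagScheduler, rdd and the closures are placeholders, and since DAGScheduler, JobSucceeded and JobFailed are Spark-internal types this is a sketch of the internal call pattern, not user-facing code.

    // Submit without blocking and keep the JobWaiter around.
    val waiter = dagScheduler.submitJob(
      rdd,
      (context: TaskContext, iter: Iterator[Int]) => iter.sum,   // per-partition function
      0 until rdd.partitions.size,                               // partitions to run
      "example call site",
      false,                                                     // allowLocal
      (index: Int, result: Int) => println("partition " + index + " -> " + result))

    // Cancellation just posts an event, so another thread may call any of these at any time:
    //   dagScheduler.cancelJob(jobId)
    //   dagScheduler.cancelJobGroup(groupId)
    //   dagScheduler.cancelAllJobs()

    // Blocking callers (runJob) now simply wait on the returned JobWaiter.
    waiter.awaitResult() match {
      case JobSucceeded            => println("job finished")
      case JobFailed(exception, _) => throw exception
    }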
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 10ff1b4376..708d221d60 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
-private[spark] sealed trait DAGSchedulerEvent
+private[scheduler] sealed trait DAGSchedulerEvent
-private[spark] case class JobSubmitted(
+private[scheduler] case class JobSubmitted(
+ jobId: Int,
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
@@ -43,9 +44,19 @@ private[spark] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
-private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
-private[spark] case class CompletionEvent(
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
+private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
+
+private[scheduler]
+case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
+private[scheduler]
+case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
+private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -54,10 +65,12 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler]
+case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
-private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
-private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler]
+case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
-private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
+private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 151514896f..7b5c0e29ad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
})
metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
- override def getValue: Int = dagScheduler.nextJobId.get()
+ override def getValue: Int = dagScheduler.numTotalJobs
})
metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 370ccd183c..1791ee660d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
import scala.collection.immutable.Set
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.security.UserGroupInformation
@@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val conf = new JobConf(configuration)
- env.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
FileInputFormat.setInputPaths(conf, path)
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
@@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val jobConf = new JobConf(configuration)
- env.hadoop.addCredentials(jobConf)
+ SparkHadoopUtil.get.addCredentials(jobConf)
FileInputFormat.setInputPaths(jobConf, path)
val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
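Across these hunks (the CheckpointRDD test at the top, HadoopRDD, and InputFormatInfo here) the code stops reaching through SparkEnv for Hadoop helpers and uses the SparkHadoopUtil.get singleton instead. A small sketch of the two calls visible in this diff, newConfiguration() and addCredentials(conf):

    import org.apache.hadoop.mapred.JobConf
    import org.apache.spark.deploy.SparkHadoopUtil

    // Build a Hadoop Configuration the way the CheckpointRDD test now does...
    val hadoopConf = SparkHadoopUtil.get.newConfiguration()

    // ...and attach credentials to a JobConf the way InputFormatInfo now does.
    val jobConf = new JobConf(hadoopConf)
    SparkHadoopUtil.get.addCredentials(jobConf)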
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 3628b1b078..60927831a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -1,292 +1,384 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.PrintWriter
-import java.io.File
-import java.io.FileNotFoundException
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
-
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
-class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
- createLogDir()
- def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
- // Create a folder for log files, the folder's name is the creation time of the jobLogger
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
- }
- }
-
- // Create a log file for one job, the file name is the jobID
- protected def createLogWriter(jobID: Int) {
- try{
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- // Close log file, and clean the stage relationship in stageIDToJobID
- protected def closeLogWriter(jobID: Int) =
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
- fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
- })
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
- }
-
- // Write log information to log file, withTime parameter controls whether to recored
- // time stamp for the information
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " +info
- }
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
- }
-
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- // Generate indents and convert to String
- protected def indentString(indent: Int) = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- protected def getRddName(rdd: RDD[_]) = {
- var rddName = rdd.getClass.getName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
- }
-
- protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
-
- // Record task metrics into job log files
- protected def recordTaskMetrics(stageID: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
- }
-
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
- }
-
- override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
- "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
- }
-
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
- val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
- case _ =>
- }
- }
-
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
- }
-
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
- val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage)
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{IOException, File, FileNotFoundException, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * per Spark job, containing the RDD graph, task start/stop events, and shuffle information.
+ * JobLogger is a subclass of SparkListener; use SparkContext.addSparkListener to register a
+ * JobLogger after the SparkContext has been created.
+ * Note that each JobLogger works with only one SparkContext.
+ * @param user The user this JobLogger records logs for (used to build the log root directory).
+ * @param logDirName The name of the directory created under the log root for this logger's files.
+ */
+
+class JobLogger(val user: String, val logDirName: String)
+ extends SparkListener with Logging {
+
+ def this() = this(System.getProperty("user.name", "<unknown>"),
+ String.valueOf(System.currentTimeMillis()))
+
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark-%s".format(user)
+
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
+ /** Create the folder for log files; by default the folder name is the creation time of the JobLogger. */
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (!dir.mkdirs()) {
+ // JobLogger should throw an exception rather than continue to construct this object.
+ throw new IOException("Error creating log directory: " + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ /**
+ * Create a log file for one job
+ * @param jobID ID of the job
+ * @exception FileNotFoundException Thrown if the log file cannot be created
+ */
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ /**
+ * Close the job's log file and remove its stages from stageIDToJobID
+ * @param jobID ID of the job
+ */
+ protected def closeLogWriter(jobID: Int) {
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+ }
+
+ /**
+ * Write info into the log file of a job
+ * @param jobID ID of the job
+ * @param info Info to be recorded
+ * @param withTime Controls whether to prefix the info with a timestamp (default is true)
+ */
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " + info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ /**
+ * Write info into the log file of the job that owns the stage
+ * @param stageID ID of the stage
+ * @param info Info to be recorded
+ * @param withTime Controls whether to prefix the info with a timestamp (default is true)
+ */
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+ }
+
+ /**
+ * Build the list of stages for a job and record each stage's owning job
+ * @param jobID ID of the job
+ * @param stage Root stage of the job
+ */
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ /**
+ * Record the stage and RDD dependencies for each stage of a job
+ * @param jobID ID of the job
+ */
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ /**
+ * Generate an indentation string
+ * @param indent Number of indent levels
+ * @return String of spaces for the given indentation level
+ */
+ protected def indentString(indent: Int): String = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ /**
+ * Get RDD's name
+ * @param rdd Input RDD
+ * @return String of RDD's name
+ */
+ protected def getRddName(rdd: RDD[_]): String = {
+ var rddName = rdd.getClass.getSimpleName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ /**
+ * Record RDD dependency graph in a stage
+ * @param jobID Job ID of the stage
+ * @param rdd Root RDD of the stage
+ * @param indent Indent number before info
+ */
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo =
+ if (rdd.getStorageLevel != StorageLevel.NONE) {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
+ rdd.origin + " " + rdd.generator
+ } else {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
+ rdd.origin + " " + rdd.generator
+ }
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+
+ /**
+ * Record the stage dependency graph of a job
+ * @param jobID ID of the job
+ * @param stage Root stage of the job
+ * @param idSet Set of stage IDs already recorded, used to avoid recording a stage twice
+ * @param indent Indent level before the info, default is 0
+ */
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ if (!idSet.contains(stage.id)) {
+ idSet += stage.id
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
+ }
+ } else {
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+ }
+
+ /**
+ * Record task metrics into job log files, including execution info and shuffle metrics
+ * @param stageID Stage ID of the task
+ * @param status Status info of the task
+ * @param taskInfo Task description info
+ * @param taskMetrics Task running metrics
+ */
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ /**
+ * When a stage is submitted, record its submission info
+ * @param stageSubmitted Stage submitted event
+ */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
+ }
+
+ /**
+ * When a stage completes, record its completion status
+ * @param stageCompleted Stage completed event
+ */
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * When a task ends, record its completion status and metrics
+ * @param taskEnd Task end event
+ */
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ /**
+ * When a job ends, record its completion status and close the log file
+ * @param jobEnd Job end event
+ */
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ /**
+ * Record job properties into job log file
+ * @param jobID ID of the job
+ * @param properties Properties of the job
+ */
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if (properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ /**
+ * When a job starts, record its properties and stage dependency graph
+ * @param jobStart Job start event
+ */
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
+
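For orientation, here is a minimal sketch of how the new JobLogger is meant to be attached, assuming the SparkContext.addSparkListener hook mentioned in the class comment and a local master; the app name and the small job below are placeholders, not part of this patch.

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.JobLogger

object JobLoggerExample {
  def main(args: Array[String]) {
    // Placeholder master and app name; adjust for a real cluster.
    val sc = new SparkContext("local", "JobLoggerExample")
    // Register after the context exists; each JobLogger serves one SparkContext.
    sc.addSparkListener(new JobLogger())
    // Run a tiny job so the logger has a job/stage/task lifecycle to record
    // under /tmp/spark-<user>/<timestamp>/<jobId> (or $SPARK_LOG_DIR if set).
    sc.parallelize(1 to 100).map(_ * 2).count()
    sc.stop()
  }
}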
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 200d881799..58f238d8cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,48 +17,58 @@
package org.apache.spark.scheduler
-import scala.collection.mutable.ArrayBuffer
-
/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/
-private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+private[spark] class JobWaiter[T](
+ dagScheduler: DAGScheduler,
+ jobId: Int,
+ totalTasks: Int,
+ resultHandler: (Int, T) => Unit)
extends JobListener {
private var finishedTasks = 0
- private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
- private var jobResult: JobResult = null // If the job is finished, this will be its result
+ // Is the job as a whole finished (succeeded or failed)?
+ private var _jobFinished = totalTasks == 0
- override def taskSucceeded(index: Int, result: Any) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
- }
- resultHandler(index, result.asInstanceOf[T])
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- jobFinished = true
- jobResult = JobSucceeded
- this.notifyAll()
- }
- }
+ def jobFinished = _jobFinished
+
+ // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
+ // partition RDDs), we set the jobResult directly to JobSucceeded.
+ private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+
+ /**
+ * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
+ * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it
+ * will fail this job with a SparkException.
+ */
+ def cancel() {
+ dagScheduler.cancelJob(jobId)
}
- override def jobFailed(exception: Exception) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
- }
- jobFinished = true
- jobResult = JobFailed(exception, None)
+ override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
+ if (_jobFinished) {
+ throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ }
+ resultHandler(index, result.asInstanceOf[T])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ _jobFinished = true
+ jobResult = JobSucceeded
this.notifyAll()
}
}
+ override def jobFailed(exception: Exception): Unit = synchronized {
+ _jobFinished = true
+ jobResult = JobFailed(exception, None)
+ this.notifyAll()
+ }
+
def awaitResult(): JobResult = synchronized {
- while (!jobFinished) {
+ while (!_jobFinished) {
this.wait()
}
return jobResult
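The rewritten JobWaiter above relies on the classic monitor pattern: state changes happen under synchronized, notifyAll() fires on completion, and awaitResult() loops around wait(). A self-contained sketch of that pattern, in plain Scala rather than Spark's classes, for readers unfamiliar with it:

// Minimal completion latch in the style of JobWaiter: callers block in
// awaitResult() until all expected "tasks" have reported in.
class SimpleWaiter(totalTasks: Int) {
  private var finished = 0
  private var done = totalTasks == 0   // zero-task jobs finish immediately

  def taskSucceeded(): Unit = synchronized {
    finished += 1
    if (finished == totalTasks) {
      done = true
      notifyAll()                      // wake any thread blocked in awaitResult()
    }
  }

  def awaitResult(): Unit = synchronized {
    while (!done) wait()               // guard against spurious wakeups
  }
}

object SimpleWaiterDemo extends App {
  val waiter = new SimpleWaiter(4)
  (1 to 4).foreach { i =>
    new Thread(new Runnable {
      def run() { Thread.sleep(50L * i); waiter.taskSucceeded() }
    }).start()
  }
  waiter.awaitResult()
  println("all tasks finished")
}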
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 9eb8d48501..596f9adde9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -43,7 +43,10 @@ private[spark] class Pool(
var runningTasks = 0
var priority = 0
- var stageId = 0
+
+ // A pool's stage id is used to break ties in scheduling.
+ var stageId = -1
+
var name = poolName
var parent: Pool = null
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 07e8317e3a..310ec62ca8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -23,7 +23,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
-import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
private[spark] object ResultTask {
@@ -32,23 +32,23 @@ private[spark] object ResultTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
- return old
+ old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(func)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -56,11 +56,11 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
- return (rdd, func)
+ (rdd, func)
}
def clearCache() {
@@ -71,29 +71,37 @@ private[spark] object ResultTask {
}
+/**
+ * A task that sends back the output to the driver application.
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd input to func
+ * @param func a function to apply on a partition of the RDD
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
+ * input RDD's partitions).
+ */
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
- var partition: Int,
+ _partitionId: Int,
@transient locs: Seq[TaskLocation],
var outputId: Int)
- extends Task[U](stageId) with Externalizable {
+ extends Task[U](stageId, _partitionId) with Externalizable {
def this() = this(0, null, null, 0, null, 0)
- var split = if (rdd == null) {
- null
- } else {
- rdd.partitions(partition)
- }
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
- override def run(attemptId: Long): U = {
- val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+ override def runTask(context: TaskContext): U = {
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
@@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U](
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+ override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
@@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U](
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
- partition = in.readInt()
+ partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 4e25086ec9..356fe56bf3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -30,7 +30,10 @@ import scala.xml.XML
* addTaskSetManager: build the leaf nodes(TaskSetManagers)
*/
private[spark] trait SchedulableBuilder {
+ def rootPool: Pool
+
def buildPools()
+
def addTaskSetManager(manager: Schedulable, properties: Properties)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index d23df0dd2b..1dc71a0428 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.storage._
-import org.apache.spark.util.{TimeStampedHashMap, MetadataCleaner}
+import org.apache.spark.util.{MetadataCleanerType, TimeStampedHashMap, MetadataCleaner}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
@@ -37,7 +37,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
@@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask {
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask {
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
- return (rdd, dep)
+ (rdd, dep)
}
}
@@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in)
val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
- return (HashMap(set.toSeq: _*))
+ HashMap(set.toSeq: _*)
}
def clearCache() {
@@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask {
}
}
+/**
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd the final RDD in this stage
+ * @param dep the ShuffleDependency
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ */
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
- var partition: Int,
+ _partitionId: Int,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId)
+ extends Task[MapStatus](stageId, _partitionId)
with Externalizable
with Logging {
@@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask(
if (locs == null) Nil else locs.toSet.toSeq
}
- var split = if (rdd == null) null else rdd.partitions(partition)
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeLong(epoch)
out.writeObject(split)
}
@@ -124,68 +136,70 @@ private[spark] class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
- partition = in.readInt()
+ partitionId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
- override def run(attemptId: Long): MapStatus = {
+ override def runTask(context: TaskContext): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
-
- val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
- metrics = Some(taskContext.taskMetrics)
+ metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
- var shuffle: ShuffleBlocks = null
- var buckets: ShuffleWriterGroup = null
+ val shuffleBlockManager = blockManager.shuffleBlockManager
+ var shuffle: ShuffleWriterGroup = null
+ var success = false
try {
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
- shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partition)
+ shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
- for (elem <- rdd.iterator(split, taskContext)) {
+ for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets.writers(bucketId).write(pair)
+ shuffle.writers(bucketId).write(pair)
}
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
- val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ var totalTime = 0L
+ val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
- writer.close()
- val size = writer.size()
+ val size = writer.fileSegment().length
totalBytes += size
+ totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
- return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ success = true
+ new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (buckets != null) {
- buckets.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null) {
+ shuffle.writers.foreach(_.revertPartialWrites())
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
- if (shuffle != null && buckets != null) {
- shuffle.releaseWriters(buckets)
+ if (shuffle != null && shuffle.writers != null) {
+ shuffle.writers.foreach(_.close())
+ shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
- taskContext.executeOnCompleteCallbacks()
+ context.executeOnCompleteCallbacks()
}
}
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
}
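To make the bucketing step in the rewritten runTask above concrete, here is a standalone sketch of routing key/value records into numPartitions buckets by key hash, mirroring what dep.partitioner.getPartition does before each pair is handed to shuffle.writers(bucketId); it uses no Spark APIs and the data is illustrative.

object BucketingSketch {
  // Non-negative modulo, as a hash partitioner typically uses.
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  def main(args: Array[String]) {
    val numPartitions = 4
    val records = Seq("apple" -> 1, "banana" -> 2, "cherry" -> 3, "date" -> 4)
    // Group each record into the bucket its key hashes to, the same way a
    // ShuffleMapTask routes each pair to its shuffle writer.
    val buckets = records.groupBy { case (key, _) =>
      nonNegativeMod(key.hashCode, numPartitions)
    }
    buckets.toSeq.sortBy(_._1).foreach { case (bucketId, elems) =>
      println("bucket " + bucketId + " -> " + elems.mkString(", "))
    }
  }
}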
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 62b521ad45..a35081f7b1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -24,13 +24,16 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
-case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
-case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+case class SparkListenerTaskGettingResult(
+ task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
@@ -54,7 +57,13 @@ trait SparkListener {
/**
* Called when a task starts
*/
- def onTaskStart(taskEnd: SparkListenerTaskStart) { }
+ def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * Called when a task begins remotely fetching its result (will not be called for tasks that do
+ * not need to fetch the result remotely).
+ */
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends
@@ -80,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: StageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ this.logInfo("Finished stage: " + stageCompleted.stage)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
//shuffle write
@@ -93,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging {
//runtime breakdown
- val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+ val runtimePcts = stageCompleted.stage.taskInfos.map{
case (info, metrics) => RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@@ -111,7 +120,7 @@ object StatsReportListener extends Logging {
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
- Distribution(stage.stageInfo.taskInfos.flatMap{
+ Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}
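As a usage sketch of the listener interface extended in this file (assuming the event classes exactly as declared in this diff, including the new SparkListenerTaskGettingResult), a listener that overrides only the callbacks it needs could look like the following; it would be registered like any other listener, e.g. sc.addSparkListener(new ResultFetchListener).

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskGettingResult, StageCompleted}

// Reports stage completion and remote result fetches; all other callbacks
// keep their default no-op implementations from the SparkListener trait.
class ResultFetchListener extends SparkListener {
  override def onStageCompleted(stageCompleted: StageCompleted) {
    println("Stage " + stageCompleted.stage.stageId + " completed")
  }

  override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
    println("Fetching result remotely for TID " + taskGettingResult.taskInfo.taskId)
  }
}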
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 4d3e4a17ba..d5824e7954 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index aa293dc6b3..7cb3fe46e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
+ val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
@@ -49,11 +50,6 @@ private[spark] class Stage(
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
-
- /** When first task was submitted to scheduler. */
- var submissionTime: Option[Long] = None
- var completionTime: Option[Long] = None
-
private var nextAttemptId = 0
def isAvailable: Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index b6f11969e5..93599dfdc8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,9 +21,16 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
-case class StageInfo(
- val stage: Stage,
+class StageInfo(
+ stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
) {
- override def toString = stage.rdd.toString
+ val stageId = stage.id
+ /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
+ var submissionTime: Option[Long] = None
+ var completionTime: Option[Long] = None
+ val rddName = stage.rdd.name
+ val name = stage.name
+ val numPartitions = stage.numPartitions
+ val numTasks = stage.numTasks
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 598d91752a..69b42e86ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,25 +17,74 @@
package org.apache.spark.scheduler
-import org.apache.spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import org.apache.spark.util.ByteBufferInputStream
+
import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.ByteBufferInputStream
+
/**
- * A task to execute on a worker node.
+ * A unit of execution. We have two kinds of Tasks in Spark:
+ * - [[org.apache.spark.scheduler.ShuffleMapTask]]
+ * - [[org.apache.spark.scheduler.ResultTask]]
+ *
+ * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
+ * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
+ * and sends the task output back to the driver application. A ShuffleMapTask executes the task
+ * and divides the task output into multiple buckets (based on the task's partitioner).
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param partitionId index of the partition in the RDD
*/
-private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
- def run(attemptId: Long): T
+private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+
+ final def run(attemptId: Long): T = {
+ context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ if (_killed) {
+ kill()
+ }
+ runTask(context)
+ }
+
+ def runTask(context: TaskContext): T
+
def preferredLocations: Seq[TaskLocation] = Nil
- var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
+ // Map output tracker epoch. Will be set by TaskScheduler.
+ var epoch: Long = -1
var metrics: Option[TaskMetrics] = None
+ // Task context, to be initialized in run().
+ @transient protected var context: TaskContext = _
+
+ // A flag to indicate whether the task is killed. This is used in case context is not yet
+ // initialized when kill() is invoked.
+ @volatile @transient private var _killed = false
+
+ /**
+ * Whether the task has been killed.
+ */
+ def killed: Boolean = _killed
+
+ /**
+ * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
+ * code and user code to properly handle the flag. This function should be idempotent so it can
+ * be called multiple times.
+ */
+ def kill() {
+ _killed = true
+ if (context != null) {
+ context.interrupted = true
+ }
+ }
}
/**
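The kill() mechanism added to Task above is cooperative: a volatile flag is flipped by the caller, and the running task is expected to notice it (directly, or via context.interrupted) at safe points. A standalone sketch of that pattern, using plain Scala rather than Spark's Task/TaskContext classes:

// Cooperative cancellation in the style of Task.kill(): the flag is volatile
// so the worker thread sees it promptly, and the loop checks it at each step.
class KillableWork {
  @volatile private var killed = false

  def kill() { killed = true }

  def run(): Int = {
    var processed = 0
    while (processed < 1000000 && !killed) {   // check the flag at a safe point
      processed += 1
    }
    processed
  }
}

object KillableWorkDemo extends App {
  val work = new KillableWork
  val t = new Thread(new Runnable { def run() { println("processed " + work.run()) } })
  t.start()
  Thread.sleep(1)
  work.kill()      // request cancellation; the worker stops at its next check
  t.join()
}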
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 7c2a422aff..4bae26f3a6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -31,9 +31,25 @@ class TaskInfo(
val host: String,
val taskLocality: TaskLocality.TaskLocality) {
+ /**
+ * The time when the task started remotely getting the result. Will not be set if the
+ * task result was sent immediately when the task finished (as opposed to sending an
+ * IndirectTaskResult and later fetching the result from the block manager).
+ */
+ var gettingResultTime: Long = 0
+
+ /**
+ * The time when the task has completed successfully (including the time to remotely fetch
+ * results, if necessary).
+ */
var finishTime: Long = 0
+
var failed = false
+ def markGettingResult(time: Long = System.currentTimeMillis) {
+ gettingResultTime = time
+ }
+
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
@@ -43,6 +59,8 @@ class TaskInfo(
failed = true
}
+ def gettingResult: Boolean = gettingResultTime != 0
+
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
@@ -52,6 +70,8 @@ class TaskInfo(
def status: String = {
if (running)
"RUNNING"
+ else if (gettingResult)
+ "GET RESULT"
else if (failed)
"FAILED"
else if (successful)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index db3954a9d3..7e468d0d67 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,13 +24,14 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.{SparkEnv}
import java.nio.ByteBuffer
import org.apache.spark.util.Utils
+import org.apache.spark.storage.BlockId
// Task result. Also contains updates to accumulator variables.
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark]
-case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable
+case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 7c2a9f03d7..10e0478108 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 * Each TaskScheduler schedules tasks for a single SparkContext.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler through
- * the TaskSchedulerListener interface.
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler.
*/
private[spark] trait TaskScheduler {
@@ -45,8 +44,11 @@ private[spark] trait TaskScheduler {
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit
- // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
- def setListener(listener: TaskSchedulerListener): Unit
+ // Cancel a stage.
+ def cancelTasks(stageId: Int)
+
+ // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setDAGScheduler(dagScheduler: DAGScheduler): Unit
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
deleted file mode 100644
index 593fa9fb93..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import scala.collection.mutable.Map
-
-import org.apache.spark.TaskEndReason
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * Interface for getting events back from the TaskScheduler.
- */
-private[spark] trait TaskSchedulerListener {
- // A task has started.
- def taskStarted(task: Task[_], taskInfo: TaskInfo)
-
- // A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
-
- // A node was added to the cluster.
- def executorGained(execId: String, host: String): Unit
-
- // A node was lost from the cluster.
- def executorLost(execId: String): Unit
-
- // The TaskScheduler wants to abort an entire task set.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156..03bf760837 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,5 +31,9 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt
+ def kill() {
+ tasks.foreach(_.kill())
+ }
+
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 1a844b7e7e..85033958ef 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -17,7 +17,6 @@
package org.apache.spark.scheduler.cluster
-import java.lang.{Boolean => JBoolean}
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
import java.util.{TimerTask, Timer}
@@ -79,14 +78,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
private val executorIdToHost = new HashMap[String, String]
- // JAR server, if any JARs were added by the user to the SparkContext
- var jarServer: HttpServer = null
-
- // URIs of JARs to pass to executor
- var jarUris: String = ""
-
// Listener object to pass upcalls into
- var listener: TaskSchedulerListener = null
+ var dagScheduler: DAGScheduler = null
var backend: SchedulerBackend = null
@@ -101,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
def initialize(context: SchedulerBackend) {
@@ -171,8 +164,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}
- def taskSetFinished(manager: TaskSetManager) {
- this.synchronized {
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the tasks and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
+ def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
+ // Check to see if the given task set has been removed. This is possible in the case of
+ // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
+ // more than one running task).
+ if (activeTaskSets.contains(manager.taskSet.id)) {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
@@ -281,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@@ -290,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
+ def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ taskSetManager.handleTaskGettingResult(tid)
+ }
+
def handleSuccessfulTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,
@@ -334,9 +354,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (backend != null) {
backend.stop()
}
- if (jarServer != null) {
- jarServer.stop()
- }
if (taskResultGetter != null) {
taskResultGetter.stop()
}
@@ -384,9 +401,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
- // Call listener.executorLost without holding the lock on this to prevent deadlock
+ // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
@@ -405,7 +422,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
def executorGained(execId: String, host: String) {
- listener.executorGained(execId, host)
+ dagScheduler.executorGained(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
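The cancelTasks change above fans out over the already-scheduled tasks of a stage and then aborts the task set. A simplified stand-in showing the two cases described in the comment; ToyBackend and ToyScheduler are invented for illustration and are not Spark types.

import scala.collection.mutable.HashMap

// Toy stand-ins: for every already-scheduled task, ask the backend to kill it
// on its executor (case 1); either way, the stage itself is then aborted (case 2).
trait ToyBackend { def killTask(taskId: Long, executorId: String): Unit }

class ToyScheduler(backend: ToyBackend) {
  private val taskIdToExecutorId = new HashMap[Long, String]

  def taskLaunched(taskId: Long, executorId: String) {
    taskIdToExecutorId(taskId) = executorId
  }

  def cancelStage(stageId: Int, runningTaskIds: Seq[Long]) {
    runningTaskIds.foreach { tid =>
      backend.killTask(tid, taskIdToExecutorId(tid))
    }
    println("Stage %d was cancelled".format(stageId))
  }
}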
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 194ab55102..ee47aaffca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,18 +17,16 @@
package org.apache.spark.scheduler.cluster
-import java.nio.ByteBuffer
-import java.util.{Arrays, NoSuchElementException}
+import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
-import scala.Some
import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- SparkException, Success, TaskEndReason, TaskResultLost, TaskState}
+ Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.util.{SystemClock, Clock}
@@ -417,11 +415,17 @@ private[spark] class ClusterTaskSetManager(
}
private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
+ }
+
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
}
/**
- * Marks the task as successful and notifies the listener that a task has ended.
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
@@ -431,7 +435,7 @@ private[spark] class ClusterTaskSetManager(
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.listener.taskEnded(
+ sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark successful and stop if all the tasks have succeeded.
@@ -447,7 +451,8 @@ private[spark] class ClusterTaskSetManager(
}
/**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
@@ -458,54 +463,57 @@ private[spark] class ClusterTaskSetManager(
val index = info.index
info.markFailed()
if (!successful(index)) {
- logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
reason.foreach {
- _ match {
- case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- successful(index) = true
- tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case ef: ExceptionFailure =>
- sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- val key = ef.description
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
+ case fetchFailed: FetchFailed =>
+ logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ successful(index) = true
+ tasksSuccessful += 1
+ sched.taskSetFinished(this)
+ removeAllRunningTasks()
+ return
+
+ case TaskKilled =>
+ logWarning("Task %d was killed.".format(tid))
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+ return
+
+ case ef: ExceptionFailure =>
+ sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ val key = ef.description
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
}
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ recentExceptions(key) = (0, now)
+ (true, 0)
}
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logWarning("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
- case TaskResultLost =>
- logInfo("Lost result for TID %s on host %s".format(tid, info.host))
- sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+ case TaskResultLost =>
+ logWarning("Lost result for TID %s on host %s".format(tid, info.host))
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
- case _ => {}
- }
+ case _ => {}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
@@ -532,7 +540,7 @@ private[spark] class ClusterTaskSetManager(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
@@ -605,7 +613,7 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
@@ -630,11 +638,11 @@ private[spark] class ClusterTaskSetManager(
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksSuccessful >= minFinishedForSpeculation) {
+ if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
val time = clock.getTime()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
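A worked example of the speculation check as modified in this hunk, assuming the usual defaults SPECULATION_QUANTILE = 0.75 and SPECULATION_MULTIPLIER = 1.5; the point of the change is that the median is now taken over the successfully finished tasks rather than indexed by numTasks.

object SpeculationSketch {
  def main(args: Array[String]) {
    val SPECULATION_QUANTILE = 0.75   // assumed default
    val SPECULATION_MULTIPLIER = 1.5  // assumed default

    val numTasks = 8
    val successfulDurations = Array(110L, 95L, 120L, 130L, 105L, 100L)  // ms

    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt  // 6
    val tasksSuccessful = successfulDurations.length
    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
      val durations = successfulDurations.sorted
      // Median over the *successful* tasks, which is the fix in this hunk.
      val medianDuration = durations(math.min((0.5 * tasksSuccessful).round.toInt, durations.length - 1))
      val threshold = math.max(SPECULATION_MULTIPLIER * medianDuration, 100)
      println("median=" + medianDuration + "ms, speculate tasks running longer than " + threshold + "ms")
    }
  }
}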
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index c0b836bf1a..53316dae2a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -24,26 +24,28 @@ import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.util.{Utils, SerializableBuffer}
-private[spark] sealed trait StandaloneClusterMessage extends Serializable
+private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
-private[spark] object StandaloneClusterMessages {
+private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
- case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
+ case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage
+
+ case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
- extends StandaloneClusterMessage
+ extends CoarseGrainedClusterMessage
- case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
+ case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
// Executors to driver
case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
- extends StandaloneClusterMessage {
+ extends CoarseGrainedClusterMessage {
Utils.checkHostPort(hostPort, "Expected host port")
}
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
- data: SerializableBuffer) extends StandaloneClusterMessage
+ data: SerializableBuffer) extends CoarseGrainedClusterMessage
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
@@ -54,10 +56,14 @@ private[spark] object StandaloneClusterMessages {
}
// Internal messages in driver
- case object ReviveOffers extends StandaloneClusterMessage
+ case object ReviveOffers extends CoarseGrainedClusterMessage
+
+ case object StopDriver extends CoarseGrainedClusterMessage
+
+ case object StopExecutor extends CoarseGrainedClusterMessage
- case object StopDriver extends StandaloneClusterMessage
+ case object StopExecutors extends CoarseGrainedClusterMessage
- case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage
+ case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
}
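
Because CoarseGrainedClusterMessages is a sealed family of serializable case classes, both ends of the driver/executor protocol dispatch on it with an exhaustive pattern match. A small standalone sketch of that dispatch style, using a simplified local mirror of a few messages (the real types above are private[spark] and are matched inside Akka actor receive blocks):

    object ClusterMessageDispatchSketch {
      // Simplified mirror of a few messages above, re-declared here only for illustration.
      sealed trait Msg extends Serializable
      case class LaunchTask(taskId: Long) extends Msg
      case class KillTask(taskId: Long, executor: String) extends Msg
      case object StopExecutor extends Msg

      def handle(msg: Msg): String = msg match {
        case LaunchTask(id)         => "launching task " + id
        case KillTask(id, executor) => "killing task " + id + " on executor " + executor
        case StopExecutor           => "shutting down executor"
      }

      def main(args: Array[String]) {
        Seq(LaunchTask(1L), KillTask(1L, "exec-0"), StopExecutor).foreach(m => println(handle(m)))
      }
    }
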
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index b6f0ec961a..3ccc38d72b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -29,16 +29,19 @@ import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycle
import org.apache.spark.{SparkException, Logging, TaskState}
import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.Utils
/**
- * A standalone scheduler backend, which waits for standalone executors to connect to it through
- * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
- * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
+ * A scheduler backend that waits for coarse-grained executors to connect to it through Akka.
+ * This backend holds onto each executor for the duration of the Spark job rather than relinquishing
+ * executors whenever a task is done and asking the scheduler to launch a new executor for
+ * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the
+ * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode
+ * (spark.deploy.*).
*/
private[spark]
-class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
@@ -84,17 +87,33 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(executorId) += 1
- makeOffers(executorId)
+ if (executorActor.contains(executorId)) {
+ freeCores(executorId) += 1
+ makeOffers(executorId)
+ } else {
+ // Ignoring the update since we don't know about the executor.
+ val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s"
+ logWarning(msg.format(taskId, state, sender, executorId))
+ }
}
case ReviveOffers =>
makeOffers()
+ case KillTask(taskId, executorId) =>
+ executorActor(executorId) ! KillTask(taskId, executorId)
+
case StopDriver =>
sender ! true
context.stop(self)
+ case StopExecutors =>
+ logInfo("Asking each executor to shut down")
+ for (executor <- executorActor.values) {
+ executor ! StopExecutor
+ }
+ sender ! true
+
case RemoveExecutor(executorId, reason) =>
removeExecutor(executorId, reason)
sender ! true
@@ -159,16 +178,31 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
driverActor = actorSystem.actorOf(
- Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
+ Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ private val timeout = {
+ Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ }
+
+ def stopExecutors() {
+ try {
+ if (driverActor != null) {
+ logInfo("Shutting down all executors")
+ val future = driverActor.ask(StopExecutors)(timeout)
+ Await.ready(future, timeout)
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error asking standalone scheduler to shut down executors", e)
+ }
+ }
override def stop() {
try {
if (driverActor != null) {
val future = driverActor.ask(StopDriver)(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
}
} catch {
case e: Exception =>
@@ -180,6 +214,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
driverActor ! ReviveOffers
}
+ override def killTask(taskId: Long, executorId: String) {
+ driverActor ! KillTask(taskId, executorId)
+ }
+
override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
.map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
@@ -187,7 +225,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
def removeExecutor(executorId: String, reason: String) {
try {
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error notifying standalone scheduler's driver actor", e)
@@ -195,6 +233,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
-private[spark] object StandaloneSchedulerBackend {
- val ACTOR_NAME = "StandaloneScheduler"
+private[spark] object CoarseGrainedSchedulerBackend {
+ val ACTOR_NAME = "CoarseGrainedScheduler"
}
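
Note that stop() and removeExecutor() above move from Await.result to Await.ready, and the new stopExecutors() uses the same pattern: the backend now only waits for the driver actor's reply future to complete instead of extracting (and possibly rethrowing) its value at the call site. A self-contained sketch of that difference using plain scala.concurrent futures (the failing future is only an illustration, not part of the patch):

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.util.{Failure, Success}

    object AwaitReadySketch {
      def main(args: Array[String]) {
        val reply: Future[Boolean] = Future { throw new RuntimeException("driver did not answer") }

        // Await.ready only blocks until completion; any failure stays inside the future.
        Await.ready(reply, 5.seconds)
        reply.value match {
          case Some(Success(v)) => println("reply: " + v)
          case Some(Failure(e)) => println("completed with failure: " + e.getMessage)
          case None             => println("unreachable after Await.ready")
        }

        // Await.result(reply, 5.seconds) would instead rethrow the RuntimeException here.
      }
    }
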
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
index d57eb3276f..5367218faa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler.cluster
-import org.apache.spark.{SparkContext}
+import org.apache.spark.SparkContext
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -30,8 +30,8 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
+ def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
+
// Memory used by each executor (in megabytes)
protected val executorMemory: Int = SparkContext.executorMemoryRequested
-
- // TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
new file mode 100644
index 0000000000..d78bdbaa7a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.{Logging, SparkContext}
+
+private[spark] class SimrSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext,
+ driverFilePath: String)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ val tmpPath = new Path(driverFilePath + "_tmp")
+ val filePath = new Path(driverFilePath)
+
+ val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
+
+ override def start() {
+ super.start()
+
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+
+ logInfo("Writing to HDFS file: " + driverFilePath)
+ logInfo("Writing Akka address: " + driverUrl)
+
+ // Create temporary file to prevent race condition where executors get empty driverUrl file
+ val temp = fs.create(tmpPath, true)
+ temp.writeUTF(driverUrl)
+ temp.writeInt(maxCores)
+ temp.close()
+
+ // "Atomic" rename
+ fs.rename(tmpPath, filePath)
+ }
+
+ override def stop() {
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+ fs.delete(new Path(driverFilePath), false)
+ super.stopExecutors()
+ super.stop()
+ }
+}
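
SimrSchedulerBackend publishes the driver's Akka URL and executor core count by writing a temporary file and renaming it into place, so executors never observe a half-written file. A hypothetical reader on the executor side might look like the sketch below; only the writeUTF/writeInt layout comes from the code above, while the method name and error handling are assumptions:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileSystem, Path}

    object SimrDriverFileReaderSketch {
      /** Reads back (driverUrl, maxCores) from the file written in SimrSchedulerBackend.start(). */
      def readDriverFile(driverFilePath: String): (String, Int) = {
        val fs = FileSystem.get(new Configuration())
        val in = fs.open(new Path(driverFilePath))
        try {
          val driverUrl = in.readUTF()  // matches temp.writeUTF(driverUrl)
          val maxCores = in.readInt()   // matches temp.writeInt(maxCores)
          (driverUrl, maxCores)
        } finally {
          in.close()
        }
      }
    }
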
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index fa83ae19d6..7127a72d6d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -26,9 +26,9 @@ import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
- master: String,
+ masters: Array[String],
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with ClientListener
with Logging {
@@ -44,15 +44,15 @@ private[spark] class SparkDeploySchedulerBackend(
// The endpoint for executors to talk to us
val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command(
- "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
+ "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(null)
val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
"http://" + sc.ui.appUIAddress)
- client = new Client(sc.env.actorSystem, master, appDesc, this)
+ client = new Client(sc.env.actorSystem, masters, appDesc, this)
client.start()
}
@@ -71,8 +71,14 @@ private[spark] class SparkDeploySchedulerBackend(
override def disconnected() {
if (!stopping) {
- logError("Disconnected from Spark cluster!")
- scheduler.error("Disconnected from Spark cluster")
+ logWarning("Disconnected from Spark cluster! Waiting for reconnection...")
+ }
+ }
+
+ override def dead() {
+ if (!stopping) {
+ logError("Spark cluster looks dead, giving up.")
+ scheduler.error("Spark cluster looks down")
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
index b2a8f06472..e68c527713 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -24,33 +24,16 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends Logging {
- private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
- private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
- private val getTaskResultExecutor = new ThreadPoolExecutor(
- MIN_THREADS,
- MAX_THREADS,
- 0L,
- TimeUnit.SECONDS,
- new LinkedBlockingDeque[Runnable],
- new ResultResolverThreadFactory)
-
- class ResultResolverThreadFactory extends ThreadFactory {
- private var counter = 0
- private var PREFIX = "Result resolver thread"
-
- override def newThread(r: Runnable): Thread = {
- val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
- counter += 1
- thread.setDaemon(true)
- return thread
- }
- }
+ private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
+ private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
+ THREADS, "Result resolver thread")
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
@@ -67,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) =>
logDebug("Fetching indirect task result for TID %s".format(tid))
+ scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed
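
The patch replaces the hand-rolled ThreadPoolExecutor and ResultResolverThreadFactory with a single call to Utils.newDaemonFixedThreadPool. That helper is not shown in this diff; a sketch of a pool with the same behaviour, a fixed number of named daemon threads, based on the code being removed rather than on the actual Utils implementation, could be:

    import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
    import java.util.concurrent.atomic.AtomicInteger

    object DaemonThreadPoolSketch {
      /** Fixed-size pool whose threads are daemons and carry a readable name prefix. */
      def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
        val counter = new AtomicInteger(0)
        val factory = new ThreadFactory {
          override def newThread(r: Runnable): Thread = {
            val t = new Thread(r, prefix + "-" + counter.getAndIncrement())
            t.setDaemon(true)
            t
          }
        }
        Executors.newFixedThreadPool(nThreads, factory).asInstanceOf[ThreadPoolExecutor]
      }

      def main(args: Array[String]) {
        val pool = newDaemonFixedThreadPool(4, "Result resolver thread")
        pool.submit(new Runnable { def run() { println(Thread.currentThread().getName) } })
        pool.shutdown()
      }
    }
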
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index bf4040fafc..8de9b72b2f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,13 +30,14 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
* onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever
* a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the
- * StandaloneBackend mechanism. This class is useful for lower and more predictable latency.
+ * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable
+ * latency.
*
* Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to
* remove this.
@@ -46,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend(
sc: SparkContext,
master: String,
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with MScheduler
with Logging {
@@ -122,20 +123,20 @@ private[spark] class CoarseMesosSchedulerBackend(
val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val uri = System.getProperty("spark.executor.uri")
if (uri == null) {
val runScript = new File(sparkHome, "spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
+ "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
- basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d"
+ .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
return command.build()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 4d1bb1c639..2699f0b33e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -17,23 +17,19 @@
package org.apache.spark.scheduler.local
-import java.io.File
-import java.lang.management.ManagementFactory
-import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import akka.actor._
-import org.apache.spark.util.Utils
+
/**
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -41,52 +37,57 @@ import org.apache.spark.util.Utils
* testing fault recovery.
*/
-private[spark]
+private[local]
case class LocalReviveOffers()
-private[spark]
+private[local]
case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+private[local]
+case class KillTask(taskId: Long)
+
private[spark]
-class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
+ extends Actor with Logging {
+
+ val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true)
def receive = {
case LocalReviveOffers =>
launchTask(localScheduler.resourceOffer(freeCores))
+
case LocalStatusUpdate(taskId, state, serializeData) =>
- freeCores += 1
- localScheduler.statusUpdate(taskId, state, serializeData)
- launchTask(localScheduler.resourceOffer(freeCores))
+ if (TaskState.isFinished(state)) {
+ freeCores += 1
+ launchTask(localScheduler.resourceOffer(freeCores))
+ }
+
+ case KillTask(taskId) =>
+ executor.killTask(taskId)
}
- def launchTask(tasks : Seq[TaskDescription]) {
+ private def launchTask(tasks: Seq[TaskDescription]) {
for (task <- tasks) {
freeCores -= 1
- localScheduler.threadPool.submit(new Runnable {
- def run() {
- localScheduler.runTask(task.taskId, task.serializedTask)
- }
- })
+ executor.launchTask(localScheduler, task.taskId, task.serializedTask)
}
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
+ with ExecutorBackend
with Logging {
- var attemptId = new AtomicInteger(0)
- var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
- var listener: TaskSchedulerListener = null
+ val attemptId = new AtomicInteger
+ var dagScheduler: DAGScheduler = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
- val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
-
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -113,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
override def submitTasks(taskSet: TaskSet) {
@@ -127,6 +128,26 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ localActor ! KillTask(tid)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
@@ -166,107 +187,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
- def runTask(taskId: Long, bytes: ByteBuffer) {
- logInfo("Running " + taskId)
- val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
- // Set the Spark execution environment for the worker thread
- SparkEnv.set(env)
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objectSer = SparkEnv.get.serializer.newInstance()
- var attemptedTask: Option[Task[_]] = None
- val start = System.currentTimeMillis()
- var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
-
- try {
- Accumulators.clear()
- Thread.currentThread().setContextClassLoader(classLoader)
-
- // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
- // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
- updateDependencies(taskFiles, taskJars) // Download any files added with addFile
- val deserializedTask = ser.deserialize[Task[_]](
- taskBytes, Thread.currentThread.getContextClassLoader)
- attemptedTask = Some(deserializedTask)
- val deserTime = System.currentTimeMillis() - start
- taskStart = System.currentTimeMillis()
-
- // Run it
- val result: Any = deserializedTask.run(taskId)
-
- // Serialize and deserialize the result to emulate what the Mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val serResult = objectSer.serialize(result)
- deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = objectSer.deserialize[Any](serResult)
- val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
- ser.serialize(Accumulators.values))
- val serviceTime = System.currentTimeMillis() - taskStart
- logInfo("Finished " + taskId)
- deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
- deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
- deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
- val taskResult = new DirectTaskResult(
- result, accumUpdates, deserializedTask.metrics.getOrElse(null))
- val serializedResult = ser.serialize(taskResult)
- localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
- } catch {
- case t: Throwable => {
- val serviceTime = System.currentTimeMillis() - taskStart
- val metrics = attemptedTask.flatMap(t => t.metrics)
- for (m <- metrics) {
- m.executorRunTime = serviceTime.toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
- }
- val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
- }
- }
- }
-
- /**
- * Download any missing dependencies if we receive a new set of files and JARs from the
- * SparkContext. Also adds any new JARs we fetched to the class loader.
- */
- private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- synchronized {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentFiles(name) = timestamp
- }
-
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
- if (!classLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- classLoader.addURL(url)
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ if (TaskState.isFinished(state)) {
+ synchronized {
+ taskIdToTaskSetId.get(taskId) match {
+ case Some(taskSetId) =>
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+
+ state match {
+ case TaskState.FINISHED =>
+ taskSetManager.taskEnded(taskId, state, serializedData)
+ case TaskState.FAILED =>
+ taskSetManager.taskFailed(taskId, state, serializedData)
+ case TaskState.KILLED =>
+ taskSetManager.error("Task %d was killed".format(taskId))
+ case _ => {}
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
}
}
+ localActor ! LocalStatusUpdate(taskId, state, serializedData)
}
}
- def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
- synchronized {
- val taskSetId = taskIdToTaskSetId(taskId)
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
- taskSetManager.statusUpdate(taskId, state, serializedData)
- }
- }
-
- override def stop() {
- threadPool.shutdownNow()
+ override def stop() {
}
override def defaultParallelism() = threads
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index c2e2399ccb..53bf78267e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -132,19 +132,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- SparkEnv.set(env)
- state match {
- case TaskState.FINISHED =>
- taskEnded(tid, state, serializedData)
- case TaskState.FAILED =>
- taskFailed(tid, state, serializedData)
- case _ => {}
- }
- }
-
def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -159,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
}
result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
+ result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
@@ -176,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
+ sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
@@ -185,9 +175,9 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
- taskSet.id, index, 4, reason.description)
+ taskSet.id, index, MAX_TASK_FAILURES, reason.description)
decreaseRunningTasks(runningTasks)
- sched.listener.taskSetFailed(taskSet, errorMessage)
+ sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
@@ -195,5 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
override def error(message: String) {
+ sched.dagScheduler.taskSetFailed(taskSet, message)
+ sched.taskSetFinished(this)
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index e936b1cfed..55b25f145a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -26,9 +26,8 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel}
-
import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel, TestBlockId}
/**
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
@@ -43,13 +42,14 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
+ val blockId = TestBlockId("1")
// Register some commonly used classes
val toRegister: Seq[AnyRef] = Seq(
ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY,
- PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock("1", ByteBuffer.allocate(1)),
- GetBlock("1"),
+ PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
+ GotBlock(blockId, ByteBuffer.allocate(1)),
+ GetBlock(blockId),
1 to 10,
1 until 10,
1L to 10L,
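
The registration above only takes effect when a job actually selects the Kryo serializer. In this version of Spark that is done through system properties set before the SparkContext is created; a minimal sketch (the property names follow the usual spark.serializer / spark.kryo.registrator convention and the registrator class is hypothetical):

    object KryoConfigSketch {
      def main(args: Array[String]) {
        // Choose the Kryo-based serializer before SparkEnv is created.
        System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        // Optional: a user-supplied registrator for application classes (hypothetical class name).
        System.setProperty("spark.kryo.registrator", "com.example.MyKryoRegistrator")
      }
    }
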
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
index 290dbce4f5..0d0a2dadc7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
@@ -18,5 +18,5 @@
package org.apache.spark.storage
private[spark]
-case class BlockException(blockId: String, message: String) extends Exception(message)
+case class BlockException(blockId: BlockId, message: String) extends Exception(message)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 3aeda3879d..e51c5b30a3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -47,7 +47,7 @@ import org.apache.spark.util.Utils
*/
private[storage]
-trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])]
with Logging with BlockFetchTracker {
def initialize()
}
@@ -57,20 +57,20 @@ private[storage]
object BlockFetcherIterator {
// A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
// the block (since we want all deserialization to happen in the calling thread); can also
// represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
def failed: Boolean = size == -1
}
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BlockFetcherIterator {
@@ -92,12 +92,12 @@ object BlockFetcherIterator {
// This represents the number of local blocks, also counting zero-sized blocks
private var numLocal = 0
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
- protected val localBlocksToFetch = new ArrayBuffer[String]()
+ protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
// This represents the number of remote blocks, also counting zero-sized blocks
private var numRemote = 0
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
- protected val remoteBlocksToFetch = new HashSet[String]()
+ protected val remoteBlocksToFetch = new HashSet[BlockId]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@@ -167,7 +167,7 @@ object BlockFetcherIterator {
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
val iterator = blockInfos.iterator
var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
+ var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
// Skip empty blocks
@@ -183,7 +183,7 @@ object BlockFetcherIterator {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
+ curBlocks = new ArrayBuffer[(BlockId, Long)]
}
}
// Add in the final request
@@ -241,7 +241,7 @@ object BlockFetcherIterator {
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val startFetchWait = System.currentTimeMillis()
val result = results.take()
@@ -267,7 +267,7 @@ object BlockFetcherIterator {
class NettyBlockFetcherIterator(
blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
@@ -303,7 +303,7 @@ object BlockFetcherIterator {
override protected def sendRequest(req: FetchRequest) {
- def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
val fetchResult = new FetchResult(blockId, blockSize,
() => dataDeserialize(blockId, blockData.nioBuffer, serializer))
results.put(fetchResult)
@@ -337,7 +337,7 @@ object BlockFetcherIterator {
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
// If all the results has been retrieved, copiers will exit automatically
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
new file mode 100644
index 0000000000..7156d855d8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+/**
+ * Identifies a particular Block of data, usually associated with a single file.
+ * A Block can be uniquely identified by its filename, but each type of Block has a different
+ * set of keys which produce its unique name.
+ *
+ * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method below.
+ */
+private[spark] sealed abstract class BlockId {
+ /** A globally unique identifier for this Block. Can be used for ser/de. */
+ def name: String
+
+ // convenience methods
+ def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
+ def isRDD = isInstanceOf[RDDBlockId]
+ def isShuffle = isInstanceOf[ShuffleBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+
+ override def toString = name
+ override def hashCode = name.hashCode
+ override def equals(other: Any): Boolean = other match {
+ case o: BlockId => getClass == o.getClass && name.equals(o.name)
+ case _ => false
+ }
+}
+
+private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
+ def name = "rdd_" + rddId + "_" + splitIndex
+}
+
+private[spark]
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+ def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
+}
+
+private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
+ def name = "broadcast_" + broadcastId
+}
+
+private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+ def name = broadcastId.name + "_" + hType
+}
+
+private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
+ def name = "taskresult_" + taskId
+}
+
+private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
+ def name = "input-" + streamId + "-" + uniqueId
+}
+
+// Intended only for testing purposes
+private[spark] case class TestBlockId(id: String) extends BlockId {
+ def name = "test_" + id
+}
+
+private[spark] object BlockId {
+ val RDD = "rdd_([0-9]+)_([0-9]+)".r
+ val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val BROADCAST = "broadcast_([0-9]+)".r
+ val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+ val TASKRESULT = "taskresult_([0-9]+)".r
+ val STREAM = "input-([0-9]+)-([0-9]+)".r
+ val TEST = "test_(.*)".r
+
+ /** Converts a BlockId "name" String back into a BlockId. */
+ def apply(id: String) = id match {
+ case RDD(rddId, splitIndex) =>
+ RDDBlockId(rddId.toInt, splitIndex.toInt)
+ case SHUFFLE(shuffleId, mapId, reduceId) =>
+ ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case BROADCAST(broadcastId) =>
+ BroadcastBlockId(broadcastId.toLong)
+ case BROADCAST_HELPER(broadcastId, hType) =>
+ BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+ case TASKRESULT(taskId) =>
+ TaskResultBlockId(taskId.toLong)
+ case STREAM(streamId, uniqueId) =>
+ StreamBlockId(streamId.toInt, uniqueId.toLong)
+ case TEST(value) =>
+ TestBlockId(value)
+ case _ =>
+ throw new IllegalStateException("Unrecognized BlockId: " + id)
+ }
+}
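
Every BlockId has a stable string name and BlockId.apply parses that name back into the typed id, which is what lets the rest of the codebase drop raw block-ID strings. A quick round-trip sketch using the case classes above (placed in org.apache.spark.storage because the ids are private[spark]):

    package org.apache.spark.storage

    object BlockIdRoundTripSketch {
      def main(args: Array[String]) {
        val shuffle = ShuffleBlockId(shuffleId = 1, mapId = 2, reduceId = 3)
        println(shuffle.name)                         // "shuffle_1_2_3"
        println(BlockId("shuffle_1_2_3") == shuffle)  // true: apply() parses the name back
        println(RDDBlockId(3, 7).isRDD)               // true
        println(BlockId("rdd_3_7").asRDDId)           // Some(rdd_3_7)
      }
    }
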
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
new file mode 100644
index 0000000000..c8f397609a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.ConcurrentHashMap
+
+private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ // To save space, 'pending' and 'failed' are encoded as special sizes:
+ @volatile var size: Long = BlockInfo.BLOCK_PENDING
+ private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
+ private def failed: Boolean = size == BlockInfo.BLOCK_FAILED
+ private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this)
+
+ setInitThread()
+
+ private def setInitThread() {
+ // Set current thread as init thread - waitForReady will not block this thread
+ // (in case there is non trivial initialization which ends up calling waitForReady as part of
+ // initialization itself)
+ BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread())
+ }
+
+ /**
+ * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
+ * Return true if the block is available, false otherwise.
+ */
+ def waitForReady(): Boolean = {
+ if (pending && initThread != Thread.currentThread()) {
+ synchronized {
+ while (pending) this.wait()
+ }
+ }
+ !failed
+ }
+
+ /** Mark this BlockInfo as ready (i.e. block is finished writing) */
+ def markReady(sizeInBytes: Long) {
+ require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes)
+ assert (pending)
+ size = sizeInBytes
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+
+ /** Mark this BlockInfo as ready but failed */
+ def markFailure() {
+ assert (pending)
+ size = BlockInfo.BLOCK_FAILED
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+}
+
+private object BlockInfo {
+ // initThread is logically a BlockInfo field, but we store it here because
+ // it's only needed while this block is in the 'pending' state and we want
+ // to minimize BlockInfo's memory footprint.
+ private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread]
+
+ private val BLOCK_PENDING: Long = -1L
+ private val BLOCK_FAILED: Long = -2L
+}
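
markReady and markFailure wake up any thread blocked in waitForReady, while the thread that created the BlockInfo is never made to wait on itself. A small hand-off sketch using the class above (placed in org.apache.spark.storage since BlockInfo is private[storage]; the sleep is only there so the reader blocks first):

    package org.apache.spark.storage

    object BlockInfoHandOffSketch {
      def main(args: Array[String]) {
        val info = new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = false)

        // A reader thread blocks in waitForReady() until the writer marks the block ready.
        val reader = new Thread(new Runnable {
          def run() {
            val ok = info.waitForReady()
            println("block available: " + ok + ", size = " + info.size)
          }
        })
        reader.start()

        Thread.sleep(100)      // give the reader time to start waiting (illustration only)
        info.markReady(2048L)  // writer finished; notifies all waiters
        reader.join()
      }
    }
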
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 7852849ce5..252329c4e1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,17 +17,18 @@
package org.apache.spark.storage
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.collection.mutable.{HashMap, ArrayBuffer}
+import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.duration._
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@@ -37,7 +38,6 @@ import org.apache.spark.util._
import sun.nio.ch.DirectBuffer
-
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -46,74 +46,20 @@ private[spark] class BlockManager(
maxMemory: Long)
extends Logging {
- private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- @volatile var pending: Boolean = true
- @volatile var size: Long = -1L
- @volatile var initThread: Thread = null
- @volatile var failed = false
-
- setInitThread()
-
- private def setInitThread() {
- // Set current thread as init thread - waitForReady will not block this thread
- // (in case there is non trivial initialization which ends up calling waitForReady as part of
- // initialization itself)
- this.initThread = Thread.currentThread()
- }
-
- /**
- * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
- * Return true if the block is available, false otherwise.
- */
- def waitForReady(): Boolean = {
- if (initThread != Thread.currentThread() && pending) {
- synchronized {
- while (pending) this.wait()
- }
- }
- !failed
- }
-
- /** Mark this BlockInfo as ready (i.e. block is finished writing) */
- def markReady(sizeInBytes: Long) {
- assert (pending)
- size = sizeInBytes
- initThread = null
- failed = false
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
-
- /** Mark this BlockInfo as ready but failed */
- def markFailure() {
- assert (pending)
- size = 0
- initThread = null
- failed = true
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
- }
-
val shuffleBlockManager = new ShuffleBlockManager(this)
+ val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
- private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
+ private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: DiskStore =
- new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
- if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@@ -154,7 +100,8 @@ private[spark] class BlockManager(
var heartBeatTask: Cancellable = null
- val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks)
+ private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks)
+ private val broadcastCleaner = new MetadataCleaner(MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks)
initialize()
// The compression codec to use. Note that the "lazy" val is necessary because we want to delay
@@ -248,7 +195,7 @@ private[spark] class BlockManager(
/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
- def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+ def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
/**
* Tell the master about the current storage status of a block. This will send a block update
@@ -258,7 +205,7 @@ private[spark] class BlockManager(
* droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
* This ensures that update in master will compensate for the increase in memory on slave.
*/
- def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) {
val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
@@ -269,11 +216,11 @@ private[spark] class BlockManager(
}
/**
- * Actually send a UpdateBlockInfo message. Returns the mater's response,
+ * Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
+ private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@@ -298,7 +245,7 @@ private[spark] class BlockManager(
/**
* Get locations of an array of blocks.
*/
- def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
+ def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).toArray
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
@@ -310,7 +257,7 @@ private[spark] class BlockManager(
* shuffle blocks. It is safe to do so without a lock on block info since disk store
* never deletes (recent) items.
*/
- def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
diskStore.getValues(blockId, serializer).orElse(
sys.error("Block " + blockId + " not found on disk, though it should be"))
}
@@ -318,94 +265,19 @@ private[spark] class BlockManager(
/**
* Get block from local block manager.
*/
- def getLocal(blockId: String): Option[Iterator[Any]] = {
+ def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
- val info = blockInfo.get(blockId).orNull
- if (info != null) {
- info.synchronized {
-
- // In the another thread is writing the block, wait for it to become ready.
- if (!info.waitForReady()) {
- // If we get here, the block write failed.
- logWarning("Block " + blockId + " was marked as failure.")
- return None
- }
-
- val level = info.level
- logDebug("Level for block " + blockId + " is " + level)
-
- // Look for the block in memory
- if (level.useMemory) {
- logDebug("Getting block " + blockId + " from memory")
- memoryStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- logDebug("Block " + blockId + " not found in memory")
- }
- }
-
- // Look for block on disk, potentially loading it back into memory if required
- if (level.useDisk) {
- logDebug("Getting block " + blockId + " from disk")
- if (level.useMemory && level.deserialized) {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- // Put the block back in memory before returning it
- // TODO: Consider creating a putValues that also takes in a iterator ?
- val elements = new ArrayBuffer[Any]
- elements ++= iterator
- memoryStore.putValues(blockId, elements, level, true).data match {
- case Left(iterator2) =>
- return Some(iterator2)
- case _ =>
- throw new Exception("Memory store did not return back an iterator")
- }
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else if (level.useMemory && !level.deserialized) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- // Put a copy of the block back in memory before returning it. Note that we can't
- // put the ByteBuffer returned by the disk store as that's a memory-mapped file.
- // The use of rewind assumes this.
- assert (0 == bytes.position())
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- bytes.rewind()
- return Some(dataDeserialize(blockId, bytes))
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
- }
- }
- } else {
- logDebug("Block " + blockId + " not registered locally")
- }
- return None
+ doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
}
/**
* Get block from the local block manager as serialized bytes.
*/
- def getLocalBytes(blockId: String): Option[ByteBuffer] = {
- // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
+ def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
logDebug("Getting local block " + blockId + " as bytes")
-
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (ShuffleBlockManager.isShuffle(blockId)) {
+ if (blockId.isShuffle) {
return diskStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
@@ -413,12 +285,15 @@ private[spark] class BlockManager(
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
+ doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+ private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = {
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
- // In the another thread is writing the block, wait for it to become ready.
+ // If another thread is writing the block, wait for it to become ready.
if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure.")
@@ -431,62 +306,104 @@ private[spark] class BlockManager(
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
- memoryStore.getBytes(blockId) match {
- case Some(bytes) =>
- return Some(bytes)
+ val result = if (asValues) {
+ memoryStore.getValues(blockId)
+ } else {
+ memoryStore.getBytes(blockId)
+ }
+ result match {
+ case Some(values) =>
+ return Some(values)
case None =>
logDebug("Block " + blockId + " not found in memory")
}
}
- // Look for block on disk
+ // Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- assert (0 == bytes.position())
- if (level.useMemory) {
- if (level.deserialized) {
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- // The memory store will hang onto the ByteBuffer, so give it a copy instead of
- // the memory-mapped file buffer we got from the disk store
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- }
- }
- bytes.rewind()
- return Some(bytes)
+ logDebug("Getting block " + blockId + " from disk")
+ val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
+ case Some(bytes) => bytes
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
+ assert (0 == bytes.position())
+
+ if (!level.useMemory) {
+ // If the block shouldn't be stored in memory, we can just return it:
+ if (asValues) {
+ return Some(dataDeserialize(blockId, bytes))
+ } else {
+ return Some(bytes)
+ }
+ } else {
+ // Otherwise, we also have to store something in the memory store:
+ if (!level.deserialized || !asValues) {
+ // We'll store the bytes in memory if the block's storage level includes
+ // "memory serialized", or if it should be cached as objects in memory
+ // but we only requested its serialized bytes:
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ bytes.rewind()
+ }
+ if (!asValues) {
+ return Some(bytes)
+ } else {
+ val values = dataDeserialize(blockId, bytes)
+ if (level.deserialized) {
+ // Cache the values before returning them:
+ // TODO: Consider creating a putValues that also takes in an iterator?
+ val valuesBuffer = new ArrayBuffer[Any]
+ valuesBuffer ++= values
+ memoryStore.putValues(blockId, valuesBuffer, level, true).data match {
+ case Left(values2) =>
+ return Some(values2)
+ case _ =>
+ throw new Exception("Memory store did not return back an iterator")
+ }
+ } else {
+ return Some(values)
+ }
+ }
+ }
}
}
} else {
logDebug("Block " + blockId + " not registered locally")
}
- return None
+ None
}
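A minimal caller-side sketch of the two local-read entry points above (hypothetical usage, not part of this patch; blockManager is assumed to be an already-constructed BlockManager):

    // getLocal and getLocalBytes both delegate to doGetLocal; only the
    // requested form (deserialized values vs. serialized bytes) differs.
    val id = RDDBlockId(0, 0)
    val values: Option[Iterator[Any]] = blockManager.getLocal(id)
    val bytes: Option[java.nio.ByteBuffer] = blockManager.getLocalBytes(id)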
/**
* Get block from remote block managers.
*/
- def getRemote(blockId: String): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
+ def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting remote block " + blockId)
- // Get locations of block
- val locations = master.getLocations(blockId)
+ doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
+ }
- // Get block from remote locations
+ /**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ logDebug("Getting remote block " + blockId + " as bytes")
+ doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+
+ private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = {
+ require(blockId != null, "BlockId is null")
+ val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
- return Some(dataDeserialize(blockId, data))
+ if (asValues) {
+ return Some(dataDeserialize(blockId, data))
+ } else {
+ return Some(data)
+ }
}
logDebug("The value of block " + blockId + " is null")
}
@@ -495,34 +412,9 @@ private[spark] class BlockManager(
}
/**
- * Get block from remote block managers as serialized bytes.
- */
- def getRemoteBytes(blockId: String): Option[ByteBuffer] = {
- // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
- // refactored.
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- logDebug("Getting remote block " + blockId + " as bytes")
-
- val locations = master.getLocations(blockId)
- for (loc <- locations) {
- logDebug("Getting remote block " + blockId + " from " + loc)
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
- if (data != null) {
- return Some(data)
- }
- logDebug("The value of block " + blockId + " is null")
- }
- logDebug("Block " + blockId + " not found")
- return None
- }
-
- /**
* Get a block from the block manager (either local or remote).
*/
- def get(blockId: String): Option[Iterator[Any]] = {
+ def get(blockId: BlockId): Option[Iterator[Any]] = {
val local = getLocal(blockId)
if (local.isDefined) {
logInfo("Found block %s locally".format(blockId))
@@ -543,7 +435,7 @@ private[spark] class BlockManager(
* so that we can control the maxMegabytesInFlight for the fetch.
*/
def getMultiple(
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
val iter =
@@ -557,7 +449,7 @@ private[spark] class BlockManager(
iter
}
- def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
+ def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
: Long = {
val elements = new ArrayBuffer[Any]
elements ++= values
@@ -566,35 +458,38 @@ private[spark] class BlockManager(
/**
* A short-circuited method to get a block writer that can write data directly to disk.
+ * The block will be appended to the given file.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
- writer.registerCloseEventHandler(() => {
- val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
- blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.size())
- })
- writer
+ val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
}
/**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
- def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
- tellMaster: Boolean = true) : Long = {
+ def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
+ tellMaster: Boolean = true) : Long = {
+ require(values != null, "Values is null")
+ doPut(blockId, Left(values), level, tellMaster)
+ }
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (values == null) {
- throw new IllegalArgumentException("Values is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
+ /**
+ * Put a new block of serialized bytes to the block manager.
+ */
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
+ tellMaster: Boolean = true) {
+ require(bytes != null, "Bytes is null")
+ doPut(blockId, Right(bytes), level, tellMaster)
+ }
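Caller-side sketch of the two put entry points above (hypothetical ids and data; blockManager is assumed to be in scope). Values are routed through Left(...), pre-serialized bytes through Right(...), and both end up in doPut:

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    val elements = ArrayBuffer[Any]("a", "b", "c")
    val buffer = ByteBuffer.wrap(Array[Byte](1, 2, 3))

    blockManager.put(TestBlockId("demo-values"), elements, StorageLevel.MEMORY_AND_DISK)
    blockManager.putBytes(TestBlockId("demo-bytes"), buffer, StorageLevel.MEMORY_ONLY_SER)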
+
+ private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+ level: StorageLevel, tellMaster: Boolean = true): Long = {
+ require(blockId != null, "BlockId is null")
+ require(level != null && level.isValid, "StorageLevel is null or invalid")
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
@@ -610,7 +505,8 @@ private[spark] class BlockManager(
return oldBlockOpt.get.size
}
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ // TODO: So the block info exists - but previous attempt to load it (?) failed.
+ // What do we do now? Retry on it?
oldBlockOpt.get
} else {
tinfo
@@ -619,10 +515,10 @@ private[spark] class BlockManager(
val startTimeMs = System.currentTimeMillis
- // If we need to replicate the data, we'll want access to the values, but because our
- // put will read the whole iterator, there will be no values left. For the case where
- // the put serializes data, we'll remember the bytes, above; but for the case where it
- // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
+ // If we're storing values and we need to replicate the data, we'll want access to the values,
+ // but because our put will read the whole iterator, there will be no values left. For the
+ // case where the put serializes data, we'll remember the bytes, above; but for the case where
+ // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
// Ditto for the bytes after the put
@@ -631,30 +527,51 @@ private[spark] class BlockManager(
// Size of the block in bytes (to return to caller)
var size = 0L
+ // If we're storing bytes, then initiate the replication before storing them locally.
+ // This is faster as data is already serialized and ready to send.
+ val replicationFuture = if (data.isRight && level.replication > 1) {
+ val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ Future {
+ replicate(blockId, bufferView, level)
+ }
+ } else {
+ null
+ }
+
myInfo.synchronized {
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will later
- // drop it to disk if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
+ data match {
+ case Left(values) => {
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
+ }
}
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
+ case Right(bytes) => {
+ bytes.rewind()
+ // Store it only in memory at first, even if useDisk is also set to true
+ (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
+ size = bytes.limit
}
}
@@ -679,132 +596,46 @@ private[spark] class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
- // Replicate block if required
+ // Either we're storing bytes and we asynchronously started replication, or we're storing
+ // values and need to serialize and replicate them now:
if (level.replication > 1) {
- val remoteStartTime = System.currentTimeMillis
- // Serialize the block if not already done
- if (bytesAfterPut == null) {
- if (valuesAfterPut == null) {
- throw new SparkException(
- "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
- }
- bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
- }
- replicate(blockId, bytesAfterPut, level)
- logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
- }
- BlockManager.dispose(bytesAfterPut)
-
- return size
- }
-
-
- /**
- * Put a new block of serialized bytes to the block manager.
- */
- def putBytes(
- blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
-
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (bytes == null) {
- throw new IllegalArgumentException("Bytes is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
-
- // Remember the block's storage level so that we can correctly drop it to disk if it needs
- // to be dropped right after it got put into memory. Note, however, that other threads will
- // not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = {
- val tinfo = new BlockInfo(level, tellMaster)
- // Do atomically !
- val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
-
- if (oldBlockOpt.isDefined) {
- if (oldBlockOpt.get.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
- oldBlockOpt.get
- } else {
- tinfo
- }
- }
-
- val startTimeMs = System.currentTimeMillis
-
- // Initiate the replication before storing it locally. This is faster as
- // data is already serialized and ready for sending
- val replicationFuture = if (level.replication > 1) {
- val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
- Future {
- replicate(blockId, bufferView, level)
- }
- } else {
- null
- }
-
- myInfo.synchronized {
- logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
- + " to get into synchronized block")
-
- var marked = false
- try {
- if (level.useMemory) {
- // Store it only in memory at first, even if useDisk is also set to true
- bytes.rewind()
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- bytes.rewind()
- diskStore.putBytes(blockId, bytes, level)
- }
-
- // assert (0 == bytes.position(), "" + bytes)
-
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- marked = true
- myInfo.markReady(bytes.limit)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
- }
- } finally {
- // If we failed at putting the block to memory/disk, notify other possible readers
- // that it has failed, and then remove it from the block info map.
- if (! marked) {
- // Note that the remove must happen before markFailure otherwise another thread
- // could've inserted a new BlockInfo before we remove it.
- blockInfo.remove(blockId)
- myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed")
+ data match {
+ case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case Left(values) => {
+ val remoteStartTime = System.currentTimeMillis
+ // Serialize the block if not already done
+ if (bytesAfterPut == null) {
+ if (valuesAfterPut == null) {
+ throw new SparkException(
+ "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ }
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ }
+ replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " +
+ Utils.getUsedTimeMs(remoteStartTime))
}
}
}
- // If replication had started, then wait for it to finish
- if (level.replication > 1) {
- Await.ready(replicationFuture, Duration.Inf)
- }
+ BlockManager.dispose(bytesAfterPut)
if (level.replication > 1) {
- logDebug("PutBytes for block " + blockId + " with replication took " +
+ logDebug("Put for block " + blockId + " with replication took " +
Utils.getUsedTimeMs(startTimeMs))
} else {
- logDebug("PutBytes for block " + blockId + " without replication took " +
+ logDebug("Put for block " + blockId + " without replication took " +
Utils.getUsedTimeMs(startTimeMs))
}
+
+ size
}
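The ordering choice above is worth spelling out: serialized bytes can be shipped as-is, so their replication is started in a Future before the local store and only awaited here, whereas raw values must first be serialized after the local put before they can be replicated. A hedged, stand-alone restatement (illustrative only, not patch code):

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    def replicationStrategy(data: Either[ArrayBuffer[Any], ByteBuffer]): String =
      data match {
        // Bytes: replication already running concurrently; just wait for it.
        case Right(_) => "await the replicationFuture started before the local store"
        // Values: serialize (or reuse bytesAfterPut), then replicate synchronously.
        case Left(_) => "serialize if needed, then call replicate(blockId, bytes, level)"
      }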
/**
* Replicate block to another node.
*/
var cachedPeers: Seq[BlockManagerId] = null
- private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) {
val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
@@ -827,14 +658,14 @@ private[spark] class BlockManager(
/**
* Read a block consisting of a single object.
*/
- def getSingle(blockId: String): Option[Any] = {
+ def getSingle(blockId: BlockId): Option[Any] = {
get(blockId).map(_.next())
}
/**
* Write a block consisting of a single object.
*/
- def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+ def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
put(blockId, Iterator(value), level, tellMaster)
}
@@ -842,7 +673,7 @@ private[spark] class BlockManager(
* Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
* store reaches its limit and needs to free up space.
*/
- def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+ def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) {
logInfo("Dropping block " + blockId + " from memory")
val info = blockInfo.get(blockId).orNull
if (info != null) {
@@ -891,16 +722,15 @@ private[spark] class BlockManager(
// TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
// from RDD.id to blocks.
logInfo("Removing RDD " + rddId)
- val rddPrefix = "rdd_" + rddId + "_"
- val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
- blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+ val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
+ blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
blocksToRemove.size
}
/**
* Remove a block from both memory and disk.
*/
- def removeBlock(blockId: String, tellMaster: Boolean = true) {
+ def removeBlock(blockId: BlockId, tellMaster: Boolean = true) {
logInfo("Removing block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) info.synchronized {
@@ -921,13 +751,22 @@ private[spark] class BlockManager(
}
}
- def dropOldBlocks(cleanupTime: Long) {
- logInfo("Dropping blocks older than " + cleanupTime)
+ private def dropOldNonBroadcastBlocks(cleanupTime: Long) {
+ logInfo("Dropping non broadcast blocks older than " + cleanupTime)
+ dropOldBlocks(cleanupTime, !_.isBroadcast)
+ }
+
+ private def dropOldBroadcastBlocks(cleanupTime: Long) {
+ logInfo("Dropping broadcast blocks older than " + cleanupTime)
+ dropOldBlocks(cleanupTime, _.isBroadcast)
+ }
+
+ private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime) {
+ if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
if (level.useMemory) {
@@ -944,39 +783,45 @@ private[spark] class BlockManager(
}
}
- def shouldCompress(blockId: String): Boolean = {
- if (ShuffleBlockManager.isShuffle(blockId)) {
- compressShuffle
- } else if (blockId.startsWith("broadcast_")) {
- compressBroadcast
- } else if (blockId.startsWith("rdd_")) {
- compressRdds
- } else {
- false // Won't happen in a real cluster, but it can in tests
- }
+ def shouldCompress(blockId: BlockId): Boolean = blockId match {
+ case ShuffleBlockId(_, _, _) => compressShuffle
+ case BroadcastBlockId(_) => compressBroadcast
+ case RDDBlockId(_, _) => compressRdds
+ case _ => false
}
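With typed BlockIds the compression decision becomes a pattern match rather than string-prefix checks. A small illustration (the id values are made up; the compress* flags are fields of the surrounding class):

    // ShuffleBlockId(shuffleId, mapId, reduceId), RDDBlockId(rddId, splitIndex) and
    // BroadcastBlockId(broadcastId) are the case classes matched above.
    shouldCompress(ShuffleBlockId(1, 2, 3))  // == compressShuffle
    shouldCompress(RDDBlockId(5, 0))         // == compressRdds
    shouldCompress(BroadcastBlockId(7))      // == compressBroadcast
    shouldCompress(TestBlockId("unit-test")) // == false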
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+ def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}
/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+ def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
+ /** Serializes into a stream. */
+ def dataSerializeStream(
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer) {
+ val byteStream = new FastBufferedOutputStream(outputStream)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a byte buffer. */
def dataSerialize(
- blockId: String,
+ blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
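A minimal round-trip sketch for the serialization helpers (dataSerialize here, dataDeserialize just below); hypothetical usage with blockManager assumed in scope and the default serializer:

    val id = TestBlockId("roundtrip")
    val bytes: java.nio.ByteBuffer = blockManager.dataSerialize(id, Iterator("a", "b", "c"))
    val back: Iterator[Any] = blockManager.dataDeserialize(id, bytes)
    // back.toList == List("a", "b", "c")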
@@ -986,7 +831,7 @@ private[spark] class BlockManager(
* the iterator is reached.
*/
def dataDeserialize(
- blockId: String,
+ blockId: BlockId,
bytes: ByteBuffer,
serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
@@ -1004,6 +849,7 @@ private[spark] class BlockManager(
memoryStore.clear()
diskStore.clear()
metadataCleaner.cancel()
+ broadcastCleaner.cancel()
logInfo("BlockManager stopped")
}
}
@@ -1041,10 +887,10 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToBlockManagers(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[BlockManagerId]] =
+ : Map[BlockId, Seq[BlockManagerId]] =
{
// env == null and blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
@@ -1054,7 +900,7 @@ private[spark] object BlockManager extends Logging {
blockManagerMaster.getLocations(blockIds)
}
- val blockManagers = new HashMap[String, Seq[BlockManagerId]]
+ val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]]
for (i <- 0 until blockIds.length) {
blockManagers(blockIds(i)) = blockLocations(i)
}
@@ -1062,19 +908,19 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToExecutorIds(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
}
def blockIdsToHosts(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 0c977f05d1..48d7101b0a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -69,7 +69,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
@@ -80,12 +80,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/** Get locations of the blockId from the driver */
- def getLocations(blockId: String): Seq[BlockManagerId] = {
+ def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the driver */
- def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
@@ -103,7 +103,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
*/
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
askDriverWithReply(RemoveBlock(blockId))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 3776951782..154a3980e9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
- private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]]
+ private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
val akkaTimeout = Duration.create(
System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
@@ -130,10 +130,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// First remove the metadata for the given RDD, and then asynchronously remove the blocks
// from the slaves.
- val prefix = "rdd_" + rddId + "_"
// Find all blocks for the given RDD, remove the block from both blockLocations and
// the blockManagerInfo that is tracking the blocks.
- val blocks = blockLocations.keySet().filter(_.startsWith(prefix))
+ val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
blocks.foreach { blockId =>
val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
@@ -199,7 +198,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- private def removeBlockFromWorkers(blockId: String) {
+ private def removeBlockFromWorkers(blockId: BlockId) {
val locations = blockLocations.get(blockId)
if (locations != null) {
locations.foreach { blockManagerId: BlockManagerId =>
@@ -229,9 +228,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- if (id.executorId == "<driver>" && !isLocal) {
- // Got a register message from the master node; don't register it
- } else if (!blockManagerInfo.contains(id)) {
+ if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.
@@ -248,7 +245,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long) {
@@ -293,11 +290,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true
}
- private def getLocations(blockId: String): Seq[BlockManagerId] = {
+ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
blockIds.map(blockId => getLocations(blockId))
}
@@ -331,7 +328,7 @@ object BlockManagerMasterActor {
private var _remainingMem: Long = maxMem
// Mapping from block id to its status.
- private val _blocks = new JHashMap[String, BlockStatus]
+ private val _blocks = new JHashMap[BlockId, BlockStatus]
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem)))
@@ -340,7 +337,7 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis()
}
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
+ def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long,
diskSize: Long) {
updateLastSeenMs()
@@ -384,7 +381,7 @@ object BlockManagerMasterActor {
}
}
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
if (_blocks.containsKey(blockId)) {
_remainingMem += _blocks.get(blockId).memSize
_blocks.remove(blockId)
@@ -395,7 +392,7 @@ object BlockManagerMasterActor {
def lastSeenMs: Long = _lastSeenMs
- def blocks: JHashMap[String, BlockStatus] = _blocks
+ def blocks: JHashMap[BlockId, BlockStatus] = _blocks
override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 24333a179c..45f51da288 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+ case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
@@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages {
class UpdateBlockInfo(
var blockManagerId: BlockManagerId,
- var blockId: String,
+ var blockId: BlockId,
var storageLevel: StorageLevel,
var memSize: Long,
var diskSize: Long)
@@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages {
override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out)
- out.writeUTF(blockId)
+ out.writeUTF(blockId.name)
storageLevel.writeExternal(out)
out.writeLong(memSize)
out.writeLong(diskSize)
@@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages {
override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in)
- blockId = in.readUTF()
+ blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()
@@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages {
object UpdateBlockInfo {
def apply(blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): UpdateBlockInfo = {
@@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages {
}
// For pattern-matching
- def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = {
Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
}
}
- case class GetLocations(blockId: String) extends ToBlockManagerMaster
+ case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster
- case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+ case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 951503019f..3a65e55733 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._
* An actor to take commands from the master to execute operations. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
+private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index 678c38203c..0c66addf9d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
}
}
- private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
blockManager.putBytes(id, bytes, level)
@@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
+ " with data size: " + bytes.limit)
}
- private def getBlock(id: String): ByteBuffer = {
+ private def getBlock(id: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + id + " started from " + startTimeMs)
val buffer = blockManager.getLocalBytes(id) match {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
index d8fa6a91d1..80dcb5a207 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.network._
-private[spark] case class GetBlock(id: String)
-private[spark] case class GotBlock(id: String, data: ByteBuffer)
-private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+private[spark] case class GetBlock(id: BlockId)
+private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
+private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
private[spark] class BlockMessage() {
// Un-initialized: typ = 0
@@ -34,7 +34,7 @@ private[spark] class BlockMessage() {
// GotBlock: typ = 2
// PutBlock: typ = 3
private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
- private var id: String = null
+ private var id: BlockId = null
private var data: ByteBuffer = null
private var level: StorageLevel = null
@@ -74,7 +74,7 @@ private[spark] class BlockMessage() {
for (i <- 1 to idLength) {
idBuilder += buffer.getChar()
}
- id = idBuilder.toString()
+ id = BlockId(idBuilder.toString)
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
@@ -109,28 +109,17 @@ private[spark] class BlockMessage() {
set(buffer)
}
- def getType: Int = {
- return typ
- }
-
- def getId: String = {
- return id
- }
-
- def getData: ByteBuffer = {
- return data
- }
-
- def getLevel: StorageLevel = {
- return level
- }
+ def getType: Int = typ
+ def getId: BlockId = id
+ def getData: ByteBuffer = data
+ def getLevel: StorageLevel = level
def toBufferMessage: BufferMessage = {
val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]()
- var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
- buffer.putInt(typ).putInt(id.length())
- id.foreach((x: Char) => buffer.putChar(x))
+ var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
+ buffer.putInt(typ).putInt(id.name.length)
+ id.name.foreach((x: Char) => buffer.putChar(x))
buffer.flip()
buffers += buffer
@@ -212,7 +201,8 @@ private[spark] object BlockMessage {
def main(args: Array[String]) {
val B = new BlockMessage()
- B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
+ val blockId = TestBlockId("ABC")
+ B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index 0aaf846b5b..6ce9127c74 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -111,14 +111,15 @@ private[spark] object BlockMessageArray {
}
def main(args: Array[String]) {
- val blockMessages =
+ val blockMessages =
(0 until 10).map { i =>
if (i % 2 == 0) {
val buffer = ByteBuffer.allocate(100)
buffer.clear
- BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER))
+ BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
+ StorageLevel.MEMORY_ONLY_SER))
} else {
- BlockMessage.fromGetBlock(GetBlock(i.toString))
+ BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
}
}
val blockMessageArray = new BlockMessageArray(blockMessages)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 39f103297f..469e68fed7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,6 +17,13 @@
package org.apache.spark.storage
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -25,22 +32,14 @@ package org.apache.spark.storage
*
* This interface does not support concurrent writes.
*/
-abstract class BlockObjectWriter(val blockId: String) {
-
- var closeEventHandler: () => Unit = _
+abstract class BlockObjectWriter(val blockId: BlockId) {
def open(): BlockObjectWriter
- def close() {
- closeEventHandler()
- }
+ def close()
def isOpen: Boolean
- def registerCloseEventHandler(handler: () => Unit) {
- closeEventHandler = handler
- }
-
/**
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
@@ -59,7 +58,126 @@ abstract class BlockObjectWriter(val blockId: String) {
def write(value: Any)
/**
- * Size of the valid writes, in bytes.
+ * Returns the file segment of committed data that this Writer has written.
+ */
+ def fileSegment(): FileSegment
+
+ /**
+ * Cumulative time spent performing blocking writes, in ns.
*/
- def size(): Long
+ def timeWriting(): Long
+}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
+ extends BlockObjectWriter(blockId)
+ with Logging
+{
+
+ /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+ def timeWriting = _timeWriting
+ private var _timeWriting = 0L
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ _timeWriting += (System.nanoTime() - start)
+ }
+
+ def write(i: Int): Unit = callWithTiming(out.write(i))
+ override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+ override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ }
+
+ private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+ /** The file channel, used for repositioning / truncating the file. */
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var fos: FileOutputStream = null
+ private var ts: TimeTrackingOutputStream = null
+ private var objOut: SerializationStream = null
+ private val initialPosition = file.length()
+ private var lastValidPosition = initialPosition
+ private var initialized = false
+ private var _timeWriting = 0L
+
+ override def open(): BlockObjectWriter = {
+ fos = new FileOutputStream(file, true)
+ ts = new TimeTrackingOutputStream(fos)
+ channel = fos.getChannel()
+ lastValidPosition = initialPosition
+ bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ objOut.flush()
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ _timeWriting += System.nanoTime() - start
+ }
+ objOut.close()
+
+ _timeWriting += ts.timeWriting
+
+ channel = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ }
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncating the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def fileSegment(): FileSegment = {
+ val bytesWritten = lastValidPosition - initialPosition
+ new FileSegment(file, initialPosition, bytesWritten)
+ }
+
+ // Only valid if called after close()
+ override def timeWriting() = _timeWriting
}
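A hedged lifecycle sketch for the DiskBlockObjectWriter defined above (illustrative only; in the patch such writers are normally obtained via BlockManager.getDiskWriter, and serializer is assumed to be some Serializer instance such as a JavaSerializer):

    import java.io.File

    val writer = new DiskBlockObjectWriter(
      TestBlockId("writer-demo"), new File("/tmp/writer-demo"),
      serializer, bufferSize = 32 * 1024, compressStream = s => s) // identity: no compression

    writer.open()
    writer.write("one record")
    val bytesCommitted = writer.commit() // flushes and records the new valid position
    writer.close()
    val segment = writer.fileSegment()   // FileSegment(file, initialPosition, bytesWritten)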
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index fa834371f4..ea42656240 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -27,7 +27,7 @@ import org.apache.spark.Logging
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
- def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
/**
* Return the size of a block in bytes.
*/
- def getSize(blockId: String): Long
+ def getSize(blockId: BlockId): Long
- def getBytes(blockId: String): Option[ByteBuffer]
+ def getBytes(blockId: BlockId): Option[ByteBuffer]
- def getValues(blockId: String): Option[Iterator[Any]]
+ def getValues(blockId: BlockId): Option[Iterator[Any]]
/**
* Remove a block, if it exists.
* @param blockId the block to remove.
* @return True if the block was found and removed, False otherwise.
*/
- def remove(blockId: String): Boolean
+ def remove(blockId: BlockId): Boolean
- def contains(blockId: String): Boolean
+ def contains(blockId: BlockId): Boolean
def clear() { }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
new file mode 100644
index 0000000000..fcd2e97982
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.Utils
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+ extends PathResolver with Logging {
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private var shuffleSender : ShuffleSender = null
+
+ addShutdownHook()
+
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+ * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ */
+ def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+ shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+ } else {
+ val file = getFile(blockId.name)
+ new FileSegment(file, 0, file.length())
+ }
+ }
+
+ def getFile(filename: String): File = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, filename)
+ }
+
+ def getFile(blockId: BlockId): File = getFile(blockId.name)
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ shuffleSender = new ShuffleSender(port, this)
+ logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+ shuffleSender.port
+ }
+}
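A worked example of the directory hashing in getFile above (the numbers are illustrative; 64 subdirectories per local dir is the default read from spark.diskStore.subDirectories):

    val hash = 1337                  // e.g. Utils.nonNegativeHash(filename)
    val numLocalDirs = 2             // entries in spark.local.dir
    val subDirsPerDir = 64
    val dirId = hash % numLocalDirs                       // 1
    val subDirId = (hash / numLocalDirs) % subDirsPerDir  // 668 % 64 = 28
    val subDirName = "%02x".format(subDirId)              // "1c"
    // The block file therefore ends up at <localDirs(1)>/1c/<filename>.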
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 63447baf8c..5a1e7b4444 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,153 +17,46 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
-import java.util.{Random, Date}
-import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
-import org.apache.spark.network.netty.ShuffleSender
-import org.apache.spark.network.netty.PathResolver
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
-private class DiskStore(blockManager: BlockManager, rootDirs: String)
+private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
- extends BlockObjectWriter(blockId) {
-
- private val f: File = createFile(blockId /*, allowAppendExisting */)
-
- // The file channel, used for repositioning / truncating the file.
- private var channel: FileChannel = null
- private var bs: OutputStream = null
- private var objOut: SerializationStream = null
- private var lastValidPosition = 0L
- private var initialized = false
-
- override def open(): DiskBlockObjectWriter = {
- val fos = new FileOutputStream(f, true)
- channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
- objOut = serializer.newInstance().serializeStream(bs)
- initialized = true
- this
- }
-
- override def close() {
- if (initialized) {
- objOut.close()
- channel = null
- bs = null
- objOut = null
- }
- // Invoke the close callback handler.
- super.close()
- }
-
- override def isOpen: Boolean = objOut != null
-
- // Flush the partial writes, and set valid length to be the length of the entire file.
- // Return the number of bytes written for this commit.
- override def commit(): Long = {
- if (initialized) {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
- } else {
- // lastValidPosition is zero if stream is uninitialized
- lastValidPosition
- }
- }
-
- override def revertPartialWrites() {
- if (initialized) {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
- }
- }
-
- override def write(value: Any) {
- if (!initialized) {
- open()
- }
- objOut.writeObject(value)
- }
-
- override def size(): Long = lastValidPosition
- }
-
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
-
- private var shuffleSender : ShuffleSender = null
- // Create one local directory for each path mentioned in spark.local.dir; then, inside this
- // directory, create multiple subdirectories that we will hash files into, in order to avoid
- // having really large inodes at the top level.
- private val localDirs: Array[File] = createLocalDirs()
- private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
-
- addShutdownHook()
-
- def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId, serializer, bufferSize)
+ override def getSize(blockId: BlockId): Long = {
+ diskManager.getBlockLocation(blockId).length
}
- override def getSize(blockId: String): Long = {
- getFile(blockId).length()
- }
-
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val channel = new RandomAccessFile(file, "rw").getChannel()
+ val file = diskManager.getFile(blockId)
+ val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
- }
-
- private def getFileBytes(file: File): ByteBuffer = {
- val length = file.length()
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, 0, length)
- } finally {
- channel.close()
- }
-
- buffer
+ file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -171,159 +64,62 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(blockId,
- new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
- objOut.writeAll(values.iterator)
- objOut.close()
- val length = file.length()
+ val file = diskManager.getFile(blockId)
+ val outputStream = new FileOutputStream(file)
+ blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(length), timeTaken))
+ file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val buffer = getFileBytes(file)
+ val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val bytes = getFileBytes(file)
- Some(bytes)
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ val segment = diskManager.getBlockLocation(blockId)
+ val channel = new RandomAccessFile(segment.file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buffer)
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
* A version of getValues that allows a custom serializer. This is used as part of the
* shuffle short-circuit code.
*/
- def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
}
- override def remove(blockId: String): Boolean = {
- val file = getFile(blockId)
- if (file.exists()) {
+ override def remove(blockId: BlockId): Boolean = {
+ val fileSegment = diskManager.getBlockLocation(blockId)
+ val file = fileSegment.file
+ if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
- false
- }
- }
-
- override def contains(blockId: String): Boolean = {
- getFile(blockId).exists()
- }
-
- private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
- val file = getFile(blockId)
- if (!allowAppendExisting && file.exists()) {
- // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
- // was rescheduled on the same machine as the old task.
- logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
- file.delete()
- }
- file
- }
-
- private def getFile(blockId: String): File = {
- logDebug("Getting file for block " + blockId)
-
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
-
- new File(subDir, blockId)
- }
-
- private def createLocalDirs(): Array[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, "spark-local-" + localDirId)
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
- }
+ if (fileSegment.length < file.length()) {
+ logWarning("Could not delete block associated with only a part of a file: " + blockId)
}
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created local directory at " + localDir)
- localDir
+ false
}
}
- private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
- if (shuffleSender != null) {
- shuffleSender.stop
- }
- }
- })
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- val pResolver = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!blockId.startsWith("shuffle_")) {
- return null
- }
- DiskStore.this.getFile(blockId).getAbsolutePath()
- }
- }
- shuffleSender = new ShuffleSender(port, pResolver)
- logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
- shuffleSender.port
+ override def contains(blockId: BlockId): Boolean = {
+ val file = diskManager.getBlockLocation(blockId).file
+ file.exists()
}
}
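
The following is a minimal sketch (an illustration only, not part of this patch) of the write loop used by the new putBytes above: the incoming ByteBuffer is duplicated so the caller's position is untouched, and write() is retried until the buffer drains, since a single call may write only part of it. The file name and payload below are hypothetical.

    import java.io.{File, FileOutputStream}
    import java.nio.ByteBuffer

    object WriteBlockSketch {
      def writeFully(file: File, _bytes: ByteBuffer) {
        val bytes = _bytes.duplicate()                 // do not disturb the caller's position/limit
        val channel = new FileOutputStream(file).getChannel()
        try {
          while (bytes.remaining > 0) {                // a single write() may be partial
            channel.write(bytes)
          }
        } finally {
          channel.close()
        }
      }

      def main(args: Array[String]) {
        val f = File.createTempFile("block-", ".bin")
        writeFully(f, ByteBuffer.wrap("hello".getBytes("UTF-8")))
        println(f.length())                            // prints 5
      }
    }
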
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
new file mode 100644
index 0000000000..555486830a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+/**
+ * References a particular segment of a file (potentially the entire file),
+ * based on an offset and a length.
+ */
+private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+}
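
As a complement, here is a minimal sketch (not part of the patch) of how the bytes referenced by a FileSegment-style (file, offset, length) triple are read back, mirroring the memory-mapped read in the new DiskStore.getBytes above. The file contents and segment bounds are made up.

    import java.io.{File, FileOutputStream, RandomAccessFile}
    import java.nio.ByteBuffer
    import java.nio.channels.FileChannel.MapMode

    object ReadSegmentSketch {
      // Map only the (offset, length) slice of the file, as DiskStore now does per block.
      def readSegment(file: File, offset: Long, length: Long): ByteBuffer = {
        val channel = new RandomAccessFile(file, "r").getChannel()
        try {
          channel.map(MapMode.READ_ONLY, offset, length)
        } finally {
          channel.close()                              // the mapping stays usable after close
        }
      }

      def main(args: Array[String]) {
        val f = File.createTempFile("segment-", ".bin")
        val out = new FileOutputStream(f)
        out.write("hello world".getBytes("UTF-8"))
        out.close()
        val buf = readSegment(f, offset = 6, length = 5)
        val arr = new Array[Byte](buf.remaining)
        buf.get(arr)
        println(new String(arr, "UTF-8"))              // prints "world"
      }
    }
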
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 77a39c71ed..05f676c6e2 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -32,7 +32,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
case class Entry(value: Any, size: Long, deserialized: Boolean)
- private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
+ private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true)
@volatile private var currentMemory = 0L
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store.
@@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
def freeMemory: Long = maxMemory - currentMemory
- override def getSize(blockId: String): Long = {
+ override def getSize(blockId: BlockId): Long = {
entries.synchronized {
entries.get(blockId).size
}
}
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -108,7 +108,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def remove(blockId: String): Boolean = {
+ override def remove(blockId: BlockId): Boolean = {
entries.synchronized {
val entry = entries.remove(blockId)
if (entry != null) {
@@ -131,14 +131,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
/**
- * Return the RDD ID that a given block ID is from, or null if it is not an RDD block.
+ * Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
*/
- private def getRddId(blockId: String): String = {
- if (blockId.startsWith("rdd_")) {
- blockId.split('_')(1)
- } else {
- null
- }
+ private def getRddId(blockId: BlockId): Option[Int] = {
+ blockId.asRDDId.map(_.rddId)
}
/**
@@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* blocks to free memory for one block, another thread may use up the freed space for
* another block.
*/
- private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = {
+ private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = {
// TODO: It's possible to optimize the locking by locking entries only when selecting blocks
// to be dropped. Once the to-be-dropped blocks have been selected, and the lock on entries has been
// released, it must be ensured that those to-be-dropped blocks are not double counted for
@@ -195,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Assumes that a lock is held by the caller to ensure only one thread is dropping blocks.
* Otherwise, the freed space may fill up before the caller puts in their new value.
*/
- private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = {
+ private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory))
@@ -207,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - currentMemory < space) {
val rddToAdd = getRddId(blockIdToAdd)
- val selectedBlocks = new ArrayBuffer[String]()
+ val selectedBlocks = new ArrayBuffer[BlockId]()
var selectedMemory = 0L
// This is synchronized to ensure that the set of entries is not changed
@@ -218,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
val pair = iterator.next()
val blockId = pair.getKey
- if (rddToAdd != null && rddToAdd == getRddId(blockId)) {
+        if (rddToAdd.isDefined && rddToAdd == getRddId(blockId)) {
logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
"block from the same RDD")
return false
@@ -252,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
return true
}
- override def contains(blockId: String): Boolean = {
+ override def contains(blockId: BlockId): Boolean = {
entries.synchronized { entries.containsKey(blockId) }
}
}
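
For readers of the new getRddId/ensureFreeSpace logic above, here is a minimal sketch (not part of the patch) of the Option-based RDD check that replaces the old "rdd_"-prefix string parsing. It assumes the BlockId hierarchy added elsewhere in this commit, with constructors RDDBlockId(rddId, splitIndex) and ShuffleBlockId(shuffleId, mapId, reduceId); the ids below are made up.

    import org.apache.spark.storage.{BlockId, RDDBlockId, ShuffleBlockId}

    object RddIdSketch {
      // Same shape as the new MemoryStore.getRddId.
      def getRddId(blockId: BlockId): Option[Int] = blockId.asRDDId.map(_.rddId)

      def main(args: Array[String]) {
        val rddBlock: BlockId = RDDBlockId(3, 0)
        val shuffleBlock: BlockId = ShuffleBlockId(1, 2, 4)
        println(getRddId(rddBlock))      // Some(3)
        println(getRddId(shuffleBlock))  // None

        // The eviction guard in ensureFreeSpace: never drop a block of the RDD being added.
        val rddToAdd = getRddId(rddBlock)
        println(rddToAdd.isDefined && rddToAdd == getRddId(rddBlock))   // true -> refuse to evict
      }
    }
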
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 9da11efb57..2f1b049ce4 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,51 +17,199 @@
package org.apache.spark.storage
-import org.apache.spark.serializer.Serializer
+import java.io.File
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
-private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+ val writers: Array[BlockObjectWriter]
-private[spark]
-trait ShuffleBlocks {
- def acquireWriters(mapId: Int): ShuffleWriterGroup
- def releaseWriters(group: ShuffleWriterGroup)
+ /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+ def releaseWriters(success: Boolean)
}
-
+/**
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
+ * per reducer (this set of files is called a ShuffleFileGroup).
+ *
+ * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
+ * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
+ * - shuffleId: The unique id given to the entire shuffle stage.
+ * - bucketId: The id of the output partition (i.e., reducer id)
+ * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ * time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within the ShuffleFileGroups associated with the block's reducer.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+  // Turning off shuffle file consolidation causes all shuffle blocks to get their own file.
+ // TODO: Remove this once the shuffle file consolidation feature is stable.
+ val consolidateShuffleFiles =
+ System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+
+ private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+ /**
+ * Contains all the state related to a particular shuffle. This includes a pool of unused
+ * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+ */
+ private class ShuffleState() {
+ val nextFileId = new AtomicInteger(0)
+ val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ }
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
- new ShuffleBlocks {
- // Get a group of writers for a map task.
- override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
- val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
- val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
- val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ type ShuffleId = Int
+ private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
+
+ private
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
+
+ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+ new ShuffleWriterGroup {
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+ private val shuffleState = shuffleStates(shuffleId)
+ private var fileGroup: ShuffleFileGroup = null
+
+ val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ fileGroup = getUnusedFileGroup()
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+ }
+ } else {
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ val blockFile = blockManager.diskBlockManager.getFile(blockId)
+ blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+ }
+ }
+
+ override def releaseWriters(success: Boolean) {
+ if (consolidateShuffleFiles) {
+ if (success) {
+ val offsets = writers.map(_.fileSegment().offset)
+ fileGroup.recordMapOutput(mapId, offsets)
+ }
+ recycleFileGroup(fileGroup)
}
- new ShuffleWriterGroup(mapId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ private def getUnusedFileGroup(): ShuffleFileGroup = {
+ val fileGroup = shuffleState.unusedFileGroups.poll()
+ if (fileGroup != null) fileGroup else newFileGroup()
+ }
+
+ private def newFileGroup(): ShuffleFileGroup = {
+ val fileId = shuffleState.nextFileId.getAndIncrement()
+ val files = Array.tabulate[File](numBuckets) { bucketId =>
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.diskBlockManager.getFile(filename)
+ }
+      val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files)
+ shuffleState.allFileGroups.add(fileGroup)
+ fileGroup
+ }
+
+ private def recycleFileGroup(group: ShuffleFileGroup) {
+ shuffleState.unusedFileGroups.add(group)
}
}
}
-}
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * This function should only be called if shuffle file consolidation is enabled, as it is
+ * an error condition if we don't find the expected block.
+ */
+ def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+ // Search all file groups associated with this shuffle.
+ val shuffleState = shuffleStates(id.shuffleId)
+ for (fileGroup <- shuffleState.allFileGroups) {
+ val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+ if (segment.isDefined) { return segment.get }
+ }
+ throw new IllegalStateException("Failed to find shuffle block: " + id)
+ }
+
+ private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+ "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ shuffleStates.clearOldValues(cleanupTime)
+ }
+}
private[spark]
object ShuffleBlockManager {
+ /**
+ * A group of shuffle files, one per reducer.
+ * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+ */
+ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+ /**
+ * Stores the absolute index of each mapId in the files of this group. For instance,
+ * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+ */
+ private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
- // Returns the block id for a given shuffle block.
- def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
- "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
- }
+ /**
+ * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+ * This ordering allows us to compute block lengths by examining the following block offset.
+ * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
+ * reducer.
+ */
+ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+ new PrimitiveVector[Long]()
+ }
+
+ def numBlocks = mapIdToIndex.size
+
+ def apply(bucketId: Int) = files(bucketId)
+
+ def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+ mapIdToIndex(mapId) = numBlocks
+ for (i <- 0 until offsets.length) {
+ blockOffsetsByReducer(i) += offsets(i)
+ }
+ }
- // Returns true if the block is a shuffle block.
- def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+ /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+ def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+ val file = files(reducerId)
+ val blockOffsets = blockOffsetsByReducer(reducerId)
+ val index = mapIdToIndex.getOrElse(mapId, -1)
+ if (index >= 0) {
+ val offset = blockOffsets(index)
+ val length =
+ if (index + 1 < numBlocks) {
+ blockOffsets(index + 1) - offset
+ } else {
+ file.length() - offset
+ }
+ assert(length >= 0)
+ Some(new FileSegment(file, offset, length))
+ } else {
+ None
+ }
+ }
+ }
}
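
To make the metadata scheme above concrete, here is a minimal sketch (not part of the patch) of the offset arithmetic behind ShuffleFileGroup.getFileSegmentFor, shown on plain Longs: a block's length is the gap to the next recorded offset, or the remainder of the file for the last block. The offsets and file length are made up.

    object SegmentLengthSketch {
      // offsets: recorded start positions of blocks in one reducer file, in write order.
      def segmentAt(offsets: IndexedSeq[Long], fileLength: Long, index: Int): (Long, Long) = {
        val offset = offsets(index)
        val length =
          if (index + 1 < offsets.length) offsets(index + 1) - offset   // next block's start bounds this one
          else fileLength - offset                                      // the last block runs to end of file
        (offset, length)
      }

      def main(args: Array[String]) {
        val offsets = IndexedSeq(0L, 120L, 300L)   // three map outputs appended back-to-back
        println(segmentAt(offsets, 450L, 1))       // (120,180)
        println(segmentAt(offsets, 450L, 2))       // (300,150)
      }
    }
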
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
new file mode 100644
index 0000000000..1e4db4f66b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -0,0 +1,86 @@
+package org.apache.spark.storage
+
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{CountDownLatch, Executors}
+
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+
+/**
+ * Utility for micro-benchmarking shuffle write performance.
+ *
+ * Writes simulated shuffle output from several threads and records the observed throughput.
+ */
+object StoragePerfTester {
+ def main(args: Array[String]) = {
+ /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
+ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g"))
+
+ /** Number of map tasks. All tasks execute concurrently. */
+ val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8)
+
+ /** Number of reduce splits for each map task. */
+ val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500)
+
+ val recordLength = 1000 // ~1KB records
+ val totalRecords = dataSizeMb * 1000
+ val recordsPerMap = totalRecords / numMaps
+
+ val writeData = "1" * recordLength
+ val executor = Executors.newFixedThreadPool(numMaps)
+
+ System.setProperty("spark.shuffle.compress", "false")
+ System.setProperty("spark.shuffle.sync", "true")
+
+ // This is only used to instantiate a BlockManager. All thread scheduling is done manually.
+ val sc = new SparkContext("local[4]", "Write Tester")
+ val blockManager = sc.env.blockManager
+
+ def writeOutputBytes(mapId: Int, total: AtomicLong) = {
+ val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
+ new KryoSerializer())
+ val writers = shuffle.writers
+ for (i <- 1 to recordsPerMap) {
+ writers(i % numOutputSplits).write(writeData)
+ }
+      writers.foreach { w =>
+ w.commit()
+ total.addAndGet(w.fileSegment().length)
+ w.close()
+ }
+
+ shuffle.releaseWriters(true)
+ }
+
+ val start = System.currentTimeMillis()
+ val latch = new CountDownLatch(numMaps)
+ val totalBytes = new AtomicLong()
+ for (task <- 1 to numMaps) {
+ executor.submit(new Runnable() {
+ override def run() = {
+ try {
+ writeOutputBytes(task, totalBytes)
+ latch.countDown()
+ } catch {
+ case e: Exception =>
+ println("Exception in child thread: " + e + " " + e.getMessage)
+ System.exit(1)
+ }
+ }
+ })
+ }
+ latch.await()
+ val end = System.currentTimeMillis()
+ val time = (end - start) / 1000.0
+ val bytesPerSecond = totalBytes.get() / time
+ val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong
+
+ System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits))
+ System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile)))
+ System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong)))
+
+ executor.shutdown()
+ sc.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 2bb7715696..1720007e4e 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -23,20 +23,24 @@ import org.apache.spark.util.Utils
private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
- blocks: Map[String, BlockStatus]) {
+ blocks: Map[BlockId, BlockStatus]) {
- def memUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
- def diskUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def memRemaining : Long = maxMem - memUsed()
+ def rddBlocks = blocks.flatMap {
+ case (rdd: RDDBlockId, status) => Some(rdd, status)
+ case _ => None
+ }
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
@@ -60,7 +64,7 @@ object StorageUtils {
/* Returns RDD-level information, compiled from a list of StorageStatus objects */
def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc)
}
/* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
@@ -71,26 +75,21 @@ object StorageUtils {
}
/* Given a list of BlockStatus objects, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
- k.substring(0,k.lastIndexOf('_'))
- }.mapValues(_.values.toArray)
+ val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
+ val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
- // Find the id of the RDD, e.g. rdd_1 => 1
- val rddId = rddKey.split("_").last.toInt
-
// Get the friendly name and storage level for the RDD, if available
sc.persistentRdds.get(rddId).map { r =>
- val rddName = Option(r.name).getOrElse(rddKey)
+ val rddName = Option(r.name).getOrElse(rddId.toString)
val rddStorageLevel = r.getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
}
@@ -101,16 +100,14 @@ object StorageUtils {
rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
- prefix: String) : Array[StorageStatus] = {
+ /* Filters storage status by a given RDD id. */
+ def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int)
+ : Array[StorageStatus] = {
storageStatusList.map { status =>
- val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
+ val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus]
//val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
}
-
}
-
}
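
A minimal sketch (not part of the patch) of the rddBlocks / memUsedByRDD filtering added above. Status is a simplified stand-in for BlockManagerMasterActor.BlockStatus, and the block ids and sizes are made up; RDDBlockId and ShuffleBlockId come from the BlockId file added in this commit.

    import org.apache.spark.storage.{BlockId, RDDBlockId, ShuffleBlockId}

    object PerRddUsageSketch {
      case class Status(memSize: Long, diskSize: Long)   // stand-in, not the real BlockStatus

      def main(args: Array[String]) {
        val blocks: Map[BlockId, Status] = Map(
          RDDBlockId(1, 0) -> Status(100, 0),
          RDDBlockId(1, 1) -> Status(50, 0),
          RDDBlockId(2, 0) -> Status(10, 0),
          ShuffleBlockId(0, 0, 0) -> Status(0, 40))

        // Same shape as StorageStatus.rddBlocks: keep only RDD blocks, keyed by RDDBlockId.
        val rddBlocks = blocks.flatMap {
          case (rdd: RDDBlockId, status) => Some(rdd -> status)
          case _ => None
        }

        // Same shape as memUsedByRDD(1): total memory used by blocks of RDD 1.
        val memUsedByRdd1 =
          rddBlocks.filterKeys(_.rddId == 1).values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
        println(memUsedByRdd1)   // 150
      }
    }
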
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index f2ae8dd97d..860e680576 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -36,11 +36,11 @@ private[spark] object ThreadingTest {
val numBlocksPerProducer = 20000
private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
- val queue = new ArrayBlockingQueue[(String, Seq[Int])](100)
+ val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100)
override def run() {
for (i <- 1 to numBlocksPerProducer) {
- val blockId = "b-" + id + "-" + i
+ val blockId = TestBlockId("b-" + id + "-" + i)
val blockSize = Random.nextInt(1000)
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
@@ -64,7 +64,7 @@ private[spark] object ThreadingTest {
private[spark] class ConsumerThread(
manager: BlockManager,
- queue: ArrayBlockingQueue[(String, Seq[Int])]
+ queue: ArrayBlockingQueue[(BlockId, Seq[Int])]
) extends Thread {
var numBlockConsumed = 0
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index 453394dfda..fcd1b518d0 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -35,7 +35,7 @@ private[spark] object UIWorkloadGenerator {
def main(args: Array[String]) {
if (args.length < 2) {
- println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
System.exit(1)
}
val master = args(0)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
index b39c0e9769..ca5a28625b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
val now = System.currentTimeMillis()
var activeTime = 0L
- for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
+ for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
activeTime += t.timeRunning(now)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index eb3b4e8522..6b854740d6 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
- val stageToPool = new HashMap[Stage, String]()
- val stageToDescription = new HashMap[Stage, String]()
- val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
+ val stageIdToPool = new HashMap[Int, String]()
+ val stageIdToDescription = new HashMap[Int, String]()
+ val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()
- val activeStages = HashSet[Stage]()
- val completedStages = ListBuffer[Stage]()
- val failedStages = ListBuffer[Stage]()
+ val activeStages = HashSet[StageInfo]()
+ val completedStages = ListBuffer[StageInfo]()
+ val failedStages = ListBuffer[StageInfo]()
// Total metrics reflect metrics only for completed tasks
var totalTime = 0L
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
- val stageToTime = HashMap[Int, Long]()
- val stageToShuffleRead = HashMap[Int, Long]()
- val stageToShuffleWrite = HashMap[Int, Long]()
- val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
- val stageToTasksComplete = HashMap[Int, Int]()
- val stageToTasksFailed = HashMap[Int, Int]()
- val stageToTaskInfos =
+ val stageIdToTime = HashMap[Int, Long]()
+ val stageIdToShuffleRead = HashMap[Int, Long]()
+ val stageIdToShuffleWrite = HashMap[Int, Long]()
+ val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
+ val stageIdToTasksComplete = HashMap[Int, Int]()
+ val stageIdToTasksFailed = HashMap[Int, Int]()
+ val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
- val stage = stageCompleted.stageInfo.stage
- poolToActiveStages(stageToPool(stage)) -= stage
+ val stage = stageCompleted.stage
+ poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
completedStages += stage
trimIfNecessary(completedStages)
}
/** If stages is too large, remove and garbage collect old stages */
- def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
+ def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > RETAINED_STAGES) {
val toRemove = RETAINED_STAGES / 10
stages.takeRight(toRemove).foreach( s => {
- stageToTaskInfos.remove(s.id)
- stageToTime.remove(s.id)
- stageToShuffleRead.remove(s.id)
- stageToShuffleWrite.remove(s.id)
- stageToTasksActive.remove(s.id)
- stageToTasksComplete.remove(s.id)
- stageToTasksFailed.remove(s.id)
- stageToPool.remove(s)
- if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
+ stageIdToTaskInfos.remove(s.stageId)
+ stageIdToTime.remove(s.stageId)
+ stageIdToShuffleRead.remove(s.stageId)
+ stageIdToShuffleWrite.remove(s.stageId)
+ stageIdToTasksActive.remove(s.stageId)
+ stageIdToTasksComplete.remove(s.stageId)
+ stageIdToTasksFailed.remove(s.stageId)
+ stageIdToPool.remove(s.stageId)
+ if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
})
stages.trimEnd(toRemove)
}
@@ -95,63 +95,69 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
- stageToPool(stage) = poolName
+ stageIdToPool(stage.stageId) = poolName
val description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
- description.map(d => stageToDescription(stage) = d)
+ description.map(d => stageIdToDescription(stage.stageId) = d)
- val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
+ val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive += taskStart.taskInfo
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList += ((taskStart.taskInfo, None, None))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
-
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
+ = synchronized {
+ // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
+    // stageIdToTaskInfos already has the updated status.
+ }
+
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
- stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+ stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
(Some(e), e.metrics)
case _ =>
- stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+ stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
- stageToTime.getOrElseUpdate(sid, 0L)
+ stageIdToTime.getOrElseUpdate(sid, 0L)
val time = metrics.map(m => m.executorRunTime).getOrElse(0)
- stageToTime(sid) += time
+ stageIdToTime(sid) += time
totalTime += time
- stageToShuffleRead.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
s.remoteBytesRead).getOrElse(0L)
- stageToShuffleRead(sid) += shuffleRead
+ stageIdToShuffleRead(sid) += shuffleRead
totalShuffleRead += shuffleRead
- stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
s.shuffleBytesWritten).getOrElse(0L)
- stageToShuffleWrite(sid) += shuffleWrite
+ stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
taskList += ((taskEnd.taskInfo, metrics, failureInfo))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@@ -159,10 +165,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
case end: SparkListenerJobEnd =>
end.jobResult match {
case JobFailed(ex, Some(stage)) =>
- activeStages -= stage
- poolToActiveStages(stageToPool(stage)) -= stage
- failedStages += stage
- trimIfNecessary(failedStages)
+ /* If two jobs share a stage we could get this failure message twice. So we first
+ * check whether we've already retired this stage. */
+ val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
+ stageInfo.foreach {s =>
+ activeStages -= s
+ poolToActiveStages(stageIdToPool(stage.id)) -= s
+ failedStages += s
+ trimIfNecessary(failedStages)
+ }
case _ =>
}
case _ =>
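
A minimal sketch (not part of the patch) of the stageId-keyed accumulation pattern used in onTaskEnd above, reduced to a single map; the stage id and byte counts are made up.

    import scala.collection.mutable.HashMap

    object StageMetricsSketch {
      val stageIdToShuffleWrite = HashMap[Int, Long]()

      def recordTask(stageId: Int, shuffleBytesWritten: Long) {
        stageIdToShuffleWrite.getOrElseUpdate(stageId, 0L)   // first task seen for this stage
        stageIdToShuffleWrite(stageId) += shuffleBytesWritten
      }

      def main(args: Array[String]) {
        recordTask(7, 1024L)
        recordTask(7, 512L)
        println(stageIdToShuffleWrite(7))   // 1536
      }
    }
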
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 06810d8dbc..cfeeccda41 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
-import org.apache.spark.scheduler.{Schedulable, Stage}
+import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
- var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
+ var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
def toNodeSeq(): Seq[Node] = {
listener.synchronized {
@@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
}
}
- private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
+ private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]
): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
@@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
</table>
}
- private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
+ private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
: Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 163a3746ea..35b5d5fd59 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val stageId = request.getParameter("id").toInt
val now = System.currentTimeMillis()
- if (!listener.stageToTaskInfos.contains(stageId)) {
+ if (!listener.stageIdToTaskInfos.contains(stageId)) {
val content =
<div>
<h4>Summary Metrics</h4> No tasks have started yet
@@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
}
- val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+ val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val numCompleted = tasks.count(_._1.finished)
- val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
+ val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
- val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
+ val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
var activeTime = 0L
- listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
val summary =
<div>
<ul class="unstyled">
<li>
<strong>CPU time: </strong>
- {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
+ {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
<li>
@@ -83,10 +83,10 @@ private[spark] class StagePage(parent: JobProgressUI) {
</div>
val taskHeaders: Seq[String] =
- Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++
- Seq("GC Time") ++
+ Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
+ Seq("Duration", "GC Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
- {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
+ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
Seq("Errors")
val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
@@ -153,6 +153,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
<tr>
+ <td>{info.index}</td>
<td>{info.taskId}</td>
<td>{info.status}</td>
<td>{info.taskLocality}</td>
@@ -169,6 +170,8 @@ private[spark] class StagePage(parent: JobProgressUI) {
Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
}}
{if (shuffleWrite) {
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
<td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
}}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 07db8622da..d7d0441c38 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -22,13 +22,13 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
-import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
+import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
-private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
+private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) {
val listener = parent.listener
val dateFmt = parent.dateFmt
@@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
}
- private def stageRow(s: Stage): Seq[Node] = {
+ private def stageRow(s: StageInfo): Seq[Node] = {
val submissionTime = s.submissionTime match {
case Some(t) => dateFmt.format(new Date(t))
case None => "Unknown"
}
- val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
+ val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
+ val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
- val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
- val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
+ val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
+ val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
+ val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
- val totalTasks = s.numPartitions
+ val totalTasks = s.numTasks
- val poolName = listener.stageToPool.get(s)
+ val poolName = listener.stageIdToPool.get(s.stageId)
val nameLink =
- <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a>
- val description = listener.stageToDescription.get(s)
+ <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.name}</a>
+ val description = listener.stageIdToDescription.get(s.stageId)
.map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
val duration = s.submissionTime.map(t => finishTime - t)
<tr>
- <td>{s.id}</td>
+ <td>{s.stageId}</td>
{if (isFairScheduler) {
<td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}>
{poolName.get}</a></td>}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 43c1257677..b83cd54f3c 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.storage.{StorageStatus, StorageUtils}
+import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils}
import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus
import org.apache.spark.ui.UIUtils._
import org.apache.spark.ui.Page._
@@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
val sc = parent.sc
def render(request: HttpServletRequest): Seq[Node] = {
- val id = request.getParameter("id")
- val prefix = "rdd_" + id.toString
+ val id = request.getParameter("id").toInt
val storageStatusList = sc.getExecutorStorageStatus
- val filteredStorageStatusList = StorageUtils.
- filterStorageStatusByPrefix(storageStatusList, prefix)
+ val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage")
- val workers = filteredStorageStatusList.map((prefix, _))
+ val workers = filteredStorageStatusList.map((id, _))
val workerTable = listingTable(workerHeaders, workerRow, workers)
val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk",
"Executors")
- val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1)
+ val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.
+ sortWith(_._1.name < _._1.name)
val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList)
val blocks = blockStatuses.map {
case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN")))
@@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage)
}
- def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = {
+ def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = {
val (id, block, locations) = row
<tr>
<td>{id}</td>
@@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
</tr>
}
- def workerRow(worker: (String, StorageStatus)): Seq[Node] = {
- val (prefix, status) = worker
+ def workerRow(worker: (Int, StorageStatus)): Seq[Node] = {
+ val (rddId, status) = worker
<tr>
<td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td>
<td>
- {Utils.bytesToString(status.memUsed(prefix))}
+ {Utils.bytesToString(status.memUsedByRDD(rddId))}
({Utils.bytesToString(status.memRemaining)} Remaining)
</td>
- <td>{Utils.bytesToString(status.diskUsed(prefix))}</td>
+ <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
new file mode 100644
index 0000000000..f60deafc6f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * A simple open hash table optimized for the append-only use case, where keys
+ * are never removed, but the value for each key may be changed.
+ *
+ * This implementation uses quadratic probing with a power-of-2 hash table
+ * size, which is guaranteed to explore all spaces for each key (see
+ * http://en.wikipedia.org/wiki/Quadratic_probing).
+ *
+ * TODO: Cache the hash values of each key? java.util.HashMap does that.
+ */
+private[spark]
+class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+
+ private var capacity = nextPowerOf2(initialCapacity)
+ private var mask = capacity - 1
+ private var curSize = 0
+
+ // Holds keys and values in the same array for memory locality; specifically, the order of
+ // elements is key0, value0, key1, value1, key2, value2, etc.
+ private var data = new Array[AnyRef](2 * capacity)
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ private val LOAD_FACTOR = 0.7
+
+ /** Get the value for a given key */
+ def apply(key: K): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ return nullValue
+ }
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (k.eq(curKey) || k == curKey) {
+ return data(2 * pos + 1).asInstanceOf[V]
+ } else if (curKey.eq(null)) {
+ return null.asInstanceOf[V]
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return null.asInstanceOf[V]
+ }
+
+ /** Set the value for a key */
+ def update(key: K, value: V): Unit = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = value
+ haveNullValue = true
+ return
+ }
+ val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
+ if (isNewEntry) {
+ incrementSize()
+ }
+ }
+
+ /**
+ * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
+ * for key, if any, or null otherwise. Returns the newly updated value.
+ */
+ def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = updateFunc(haveNullValue, nullValue)
+ haveNullValue = true
+ return nullValue
+ }
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (k.eq(curKey) || k == curKey) {
+ val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ return newValue
+ } else if (curKey.eq(null)) {
+ val newValue = updateFunc(false, null.asInstanceOf[V])
+ data(2 * pos) = k
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ incrementSize()
+ return newValue
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ null.asInstanceOf[V] // Never reached but needed to keep compiler happy
+ }
+
+ /** Iterator method from Iterable */
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
+ var pos = -1
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def nextValue(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ }
+ pos += 1
+ }
+ null
+ }
+
+ override def hasNext: Boolean = nextValue() != null
+
+ override def next(): (K, V) = {
+ val value = nextValue()
+ if (value == null) {
+ throw new NoSuchElementException("End of iterator")
+ }
+ pos += 1
+ value
+ }
+ }
+
+ override def size: Int = curSize
+
+ /** Increase table size by 1, rehashing if necessary */
+ private def incrementSize() {
+ curSize += 1
+ if (curSize > LOAD_FACTOR * capacity) {
+ growTable()
+ }
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def rehash(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ /**
+ * Put an entry into a table represented by data, returning true if
+ * this increases the size of the table or false otherwise. Assumes
+ * that "data" has at least one empty slot.
+ */
+ private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
+ val mask = (data.length / 2) - 1
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = key
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return true
+ } else if (curKey.eq(key) || curKey == key) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return false
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return false // Never reached but needed to keep compiler happy
+ }
+
+ /** Double the table's size and re-hash everything */
+ private def growTable() {
+ val newCapacity = capacity * 2
+ if (newCapacity >= (1 << 30)) {
+ // We can't make the table this big because we want an array of 2x
+ // that size for our data, but array sizes are at most Int.MaxValue
+ throw new Exception("Can't make capacity bigger than 2^29 elements")
+ }
+ val newData = new Array[AnyRef](2 * newCapacity)
+ var pos = 0
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ putInto(newData, data(2 * pos), data(2 * pos + 1))
+ }
+ pos += 1
+ }
+ data = newData
+ capacity = newCapacity
+ mask = newCapacity - 1
+ }
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
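
A minimal sketch (not part of the patch) exercising AppendOnlyMap's changeValue for simple aggregation, assuming the class above is on the classpath. The input words are made up.

    import org.apache.spark.util.AppendOnlyMap

    object AppendOnlyMapSketch {
      def main(args: Array[String]) {
        val counts = new AppendOnlyMap[String, Int]()
        for (word <- Seq("a", "b", "a", "c", "a")) {
          // hadValue is false the first time a key is seen; oldValue is valid only when hadValue is true.
          counts.changeValue(word, (hadValue, oldValue) => if (hadValue) oldValue + 1 else 1)
        }
        println(counts("a"))               // 3
        println(counts.size)               // 3
        counts.iterator.foreach(println)   // (a,3), (b,1), (c,1) in table order
      }
    }
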
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index a430a75451..67a7f87a5c 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -17,7 +17,6 @@
package org.apache.spark.util
-import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
import java.util.{TimerTask, Timer}
import org.apache.spark.Logging
@@ -25,11 +24,14 @@ import org.apache.spark.Logging
/**
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
-class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging {
+ val name = cleanerType.toString
+
private val delaySeconds = MetadataCleaner.getDelaySeconds
private val periodSeconds = math.max(10, delaySeconds / 10)
private val timer = new Timer(name + " cleanup timer", true)
+
private val task = new TimerTask {
override def run() {
try {
@@ -53,9 +55,38 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
}
}
+object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
+ "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
+
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
+
+ type MetadataCleanerType = Value
+
+ def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = "spark.cleaner.ttl." + which.toString
+}
object MetadataCleaner {
+
+  // Using only system properties for now, so that workers can also read the values while preserving the earlier behavior.
def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt
- def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) }
+
+ def getDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
+ System.getProperty(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds.toString).toInt
+ }
+
+ def setDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) {
+ System.setProperty(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
+ }
+
+ def setDelaySeconds(delay: Int, resetAll: Boolean = true) {
+    // Override the global TTL; when resetAll is true, also drop any per-cleaner overrides.
+ System.setProperty("spark.cleaner.ttl", delay.toString)
+ if (resetAll) {
+ for (cleanerType <- MetadataCleanerType.values) {
+ System.clearProperty(MetadataCleanerType.systemProperty(cleanerType))
+ }
+ }
+ }
}
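A minimal sketch of how the per-cleaner TTL accessors added above could be used. It assumes the properties are set before the corresponding MetadataCleaner instances are constructed (the TTLs are read from system properties), that the sketch compiles inside the org.apache.spark package in case these helpers are package-private, and that the demo object name is hypothetical:

package org.apache.spark

import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType}

object CleanerTtlDemo {
  def main(args: Array[String]) {
    // Set a global TTL of 3600s; resetAll (default true) clears per-cleaner overrides.
    MetadataCleaner.setDelaySeconds(3600)

    // Override only the BlockManager cleaner; every other cleaner keeps the global value.
    MetadataCleaner.setDelaySeconds(MetadataCleanerType.BLOCK_MANAGER, 600)

    assert(MetadataCleaner.getDelaySeconds(MetadataCleanerType.BLOCK_MANAGER) == 600)
    assert(MetadataCleaner.getDelaySeconds(MetadataCleanerType.SPARK_CONTEXT) == 3600)
  }
}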
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 94ce50e964..7557ddab19 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,16 +18,12 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address}
import java.util.{Locale, Random, UUID}
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
-import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
-import java.util.regex.Pattern
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.io.Source
@@ -43,7 +39,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
-import org.apache.spark.{SparkEnv, SparkException, Logging}
+import org.apache.spark.{SparkException, Logging}
/**
@@ -155,7 +151,7 @@ private[spark] object Utils extends Logging {
return buf
}
- private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+ private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
@@ -287,9 +283,8 @@ private[spark] object Utils extends Logging {
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val env = SparkEnv.get
val uri = new URI(url)
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@@ -454,14 +449,17 @@ private[spark] object Utils extends Logging {
hostPortParseResults.get(hostPort)
}
- private[spark] val daemonThreadFactory: ThreadFactory =
- new ThreadFactoryBuilder().setDaemon(true).build()
+ private val daemonThreadFactoryBuilder: ThreadFactoryBuilder =
+ new ThreadFactoryBuilder().setDaemon(true)
/**
- * Wrapper over newCachedThreadPool.
+ * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor =
- Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
@@ -472,10 +470,13 @@ private[spark] object Utils extends Logging {
}
/**
- * Wrapper over newFixedThreadPool.
+ * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
- Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
private def listFilesSafely(file: File): Seq[File] = {
val files = file.listFiles()
@@ -820,4 +821,10 @@ private[spark] object Utils extends Logging {
// Nothing else to guard against ?
hashAbs
}
+
+  /** Returns a copy of the system properties that is thread-safe to iterate over. */
+ def getSystemProperties(): Map[String, String] = {
+ return System.getProperties().clone()
+ .asInstanceOf[java.util.Properties].toMap[String, String]
+ }
}
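A short sketch exercising the renamed daemon thread-pool helpers and the new getSystemProperties added above. Utils is private[spark], so the sketch assumes it is compiled inside the org.apache.spark package; the object name is hypothetical:

package org.apache.spark

import java.util.concurrent.TimeUnit

import org.apache.spark.util.Utils

object ThreadPoolNamingDemo {
  def main(args: Array[String]) {
    // Threads created by this pool are daemons named "my-pool-0", "my-pool-1", ...
    val pool = Utils.newDaemonCachedThreadPool("my-pool")
    pool.execute(new Runnable {
      override def run() { println(Thread.currentThread().getName) }
    })
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)

    // A snapshot of the JVM system properties that is safe to iterate over.
    val props = Utils.getSystemProperties()
    println(props.getOrElse("java.version", "unknown"))
  }
}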
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
new file mode 100644
index 0000000000..a1a452315d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
+ * safety and bounds checking.
+ */
+class BitSet(numBits: Int) {
+
+ private[this] val words = new Array[Long](bit2words(numBits))
+ private[this] val numWords = words.length
+
+ /**
+ * Sets the bit at the specified index to true.
+ * @param index the bit index
+ */
+ def set(index: Int) {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ words(index >> 6) |= bitmask // div by 64 and mask
+ }
+
+ /**
+ * Return the value of the bit with the specified index. The value is true if the bit with
+ * the index is currently set in this BitSet; otherwise, the result is false.
+ *
+ * @param index the bit index
+ * @return the value of the bit with the specified index
+ */
+ def get(index: Int): Boolean = {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ (words(index >> 6) & bitmask) != 0 // div by 64 and mask
+ }
+
+ /** Return the number of bits set to true in this BitSet. */
+ def cardinality(): Int = {
+ var sum = 0
+ var i = 0
+ while (i < numWords) {
+ sum += java.lang.Long.bitCount(words(i))
+ i += 1
+ }
+ sum
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then -1 is returned.
+ *
+ * To iterate over the true bits in a BitSet, use the following loop:
+ *
+   *   var i = bs.nextSetBit(0)
+   *   while (i >= 0) {
+   *     // operate on index i here
+   *     i = bs.nextSetBit(i + 1)
+   *   }
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ def nextSetBit(fromIndex: Int): Int = {
+ var wordIndex = fromIndex >> 6
+ if (wordIndex >= numWords) {
+ return -1
+ }
+
+ // Try to find the next set bit in the current word
+ val subIndex = fromIndex & 0x3f
+ var word = words(wordIndex) >> subIndex
+ if (word != 0) {
+ return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
+ }
+
+ // Find the next set bit in the rest of the words
+ wordIndex += 1
+ while (wordIndex < numWords) {
+ word = words(wordIndex)
+ if (word != 0) {
+ return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
+ }
+ wordIndex += 1
+ }
+
+ -1
+ }
+
+ /** Return the number of longs it would take to hold numBits. */
+ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
+}
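A minimal usage sketch for the new BitSet, including the iteration idiom from the nextSetBit scaladoc (the demo object name is hypothetical):

import org.apache.spark.util.collection.BitSet

object BitSetDemo {
  def main(args: Array[String]) {
    val bs = new BitSet(100)             // fixed size: valid indices are 0 until 100
    Seq(3, 40, 99).foreach(bs.set(_))
    assert(bs.get(40) && !bs.get(41))
    assert(bs.cardinality() == 3)

    // Iterate over the set bits using nextSetBit, as described in its scaladoc.
    var i = bs.nextSetBit(0)
    while (i >= 0) {
      println(i)                         // prints 3, 40, 99
      i = bs.nextSetBit(i + 1)
    }
  }
}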
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
new file mode 100644
index 0000000000..45849b3380
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+
+/**
+ * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
+ * but not deletions. This map is about 5X faster than java.util.HashMap, with much less
+ * space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ protected var _keySet = new OpenHashSet[K](initialCapacity)
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ private var _values: Array[V] = _
+ _values = new Array[V](_keySet.capacity)
+
+ @transient private var _oldValues: Array[V] = null
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ if (k == null) {
+ nullValue
+ } else {
+ val pos = _keySet.getPos(k)
+ if (pos < 0) {
+ null.asInstanceOf[V]
+ } else {
+ _values(pos)
+ }
+ }
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ if (k == null) {
+ haveNullValue = true
+ nullValue = v
+ } else {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ if (k == null) {
+ if (haveNullValue) {
+ nullValue = mergeValue(nullValue)
+ } else {
+ haveNullValue = true
+ nullValue = defaultValue
+ }
+ nullValue
+ } else {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = -1
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ pos += 1
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
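A minimal word-count style sketch for the new OpenHashMap, showing updates via changeValue and the special handling of the null key. OpenHashMap is private[spark], so the sketch assumes it compiles inside the org.apache.spark package; the object name is hypothetical:

package org.apache.spark

import org.apache.spark.util.collection.OpenHashMap

object OpenHashMapDemo {
  def main(args: Array[String]) {
    val counts = new OpenHashMap[String, Int]()   // default initial capacity of 64
    val words = Seq("a", "b", "a", null, "a")

    // changeValue inserts defaultValue for a new key and applies mergeValue to an existing one.
    // The null key is tracked separately and never enters the underlying OpenHashSet.
    words.foreach(w => counts.changeValue(w, 1, _ + 1))

    assert(counts("a") == 3)
    assert(counts(null) == 1)
    assert(counts.size == 3)   // "a", "b" and the null key
  }
}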
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
new file mode 100644
index 0000000000..49d95afdb9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect._
+
+/**
+ * A simple, fast hash set optimized for the non-null, insertion-only use case, where keys are
+ * never removed.
+ *
+ * The underlying implementation uses the Scala compiler's specialization to generate optimized
+ * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
+ * while incurring much less memory overhead, and can serve as a building block for higher level
+ * data structures such as an optimized hash map.
+ *
+ * Compared with standard hash set implementations, this class provides various callback
+ * interfaces (e.g. allocateFunc, moveFunc) and methods to retrieve the position of a key in the
+ * underlying array.
+ *
+ * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
+ * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
+ */
+private[spark]
+class OpenHashSet[@specialized(Long, Int) T: ClassTag](
+ initialCapacity: Int,
+ loadFactor: Double)
+ extends Serializable {
+
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+ require(loadFactor < 1.0, "Load factor must be less than 1.0")
+ require(loadFactor > 0.0, "Load factor must be greater than 0.0")
+
+ import OpenHashSet._
+
+ def this(initialCapacity: Int) = this(initialCapacity, 0.7)
+
+ def this() = this(64)
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+
+ protected val hasher: Hasher[T] = {
+ // It would've been more natural to write the following using pattern matching. But Scala 2.9.x
+ // compiler has a bug when specialization is used together with this pattern matching, and
+ // throws:
+ // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
+ // found : scala.reflect.AnyValManifest[Long]
+ // required: scala.reflect.ClassTag[Int]
+ // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
+ // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
+ // ...
+ val mt = classTag[T]
+ if (mt == ClassTag.Long) {
+ (new LongHasher).asInstanceOf[Hasher[T]]
+ } else if (mt == ClassTag.Int) {
+ (new IntHasher).asInstanceOf[Hasher[T]]
+ } else {
+ new Hasher[T]
+ }
+ }
+
+ protected var _capacity = nextPowerOf2(initialCapacity)
+ protected var _mask = _capacity - 1
+ protected var _size = 0
+
+ protected var _bitset = new BitSet(_capacity)
+
+ // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
+ // specialization bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _data: Array[T] = _
+ _data = new Array[T](_capacity)
+
+ /** Number of elements in the set. */
+ def size: Int = _size
+
+ /** The capacity of the set (i.e. size of the underlying array). */
+ def capacity: Int = _capacity
+
+ /** Return true if this set contains the specified element. */
+ def contains(k: T): Boolean = getPos(k) != INVALID_POS
+
+ /**
+ * Add an element to the set. If the set is over capacity after the insertion, grow the set
+ * and rehash all elements.
+ */
+ def add(k: T) {
+ addWithoutResize(k)
+ rehashIfNeeded(k, grow, move)
+ }
+
+ /**
+ * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
+ * The caller is responsible for calling rehashIfNeeded.
+ *
+   * Use (retval & POSITION_MASK) to get the actual position, and
+   * (retval & NONEXISTENCE_MASK) != 0 to check whether the key was absent before this call.
+   *
+   * @return The position where the key is placed, with the highest order bit set if the key
+   *         did not exist in the set previously.
+ */
+ def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+
+ /**
+ * Rehash the set if it is overloaded.
+   * @param k A parameter unused in the function, present only to force the Scala compiler to
+   *          specialize this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ if (_size > loadFactor * _capacity) {
+ rehash(k, allocateFunc, moveFunc)
+ }
+ }
+
+ /**
+ * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
+ */
+ def getPos(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ return INVALID_POS
+ } else if (k == _data(pos)) {
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ INVALID_POS
+ }
+
+ /** Return the value at the specified position. */
+ def getValue(pos: Int): T = _data(pos)
+
+ /**
+ * Return the next position with an element stored, starting from the given position inclusively.
+ */
+ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
+
+ /**
+   * Put an entry into the set. Return the position where the key is placed. In addition, the
+   * highest bit in the returned position is set if the key did not exist prior to this put.
+ *
+ * This function assumes the data array has at least one empty slot.
+ */
+ private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
+ val mask = data.length - 1
+ var pos = hashcode(hasher.hash(k)) & mask
+ var i = 1
+ while (true) {
+ if (!bitset.get(pos)) {
+ // This is a new key.
+ data(pos) = k
+ bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
+
+ /**
+   * Double the table's size and re-hash everything. We are not really using k, but it is declared
+   * so the Scala compiler can specialize this method (which leads to calling the specialized
+   * version of putInto).
+ *
+   * @param k A parameter unused in the function, present only to force the Scala compiler to
+   *          specialize this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ val newCapacity = _capacity * 2
+ require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+
+ allocateFunc(newCapacity)
+ val newData = new Array[T](newCapacity)
+ val newBitset = new BitSet(newCapacity)
+ var pos = 0
+ _size = 0
+ while (pos < _capacity) {
+ if (_bitset.get(pos)) {
+ val newPos = putInto(newBitset, newData, _data(pos))
+ moveFunc(pos, newPos & POSITION_MASK)
+ }
+ pos += 1
+ }
+ _bitset = newBitset
+ _data = newData
+ _capacity = newCapacity
+ _mask = newCapacity - 1
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def hashcode(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
+
+
+private[spark]
+object OpenHashSet {
+
+ val INVALID_POS = -1
+ val NONEXISTENCE_MASK = 0x80000000
+  val POSITION_MASK = 0x7FFFFFFF // masks off the NONEXISTENCE_MASK bit to recover the position
+
+ /**
+ * A set of specialized hash function implementation to avoid boxing hash code computation
+ * in the specialized implementation of OpenHashSet.
+ */
+ sealed class Hasher[@specialized(Long, Int) T] {
+ def hash(o: T): Int = o.hashCode()
+ }
+
+ class LongHasher extends Hasher[Long] {
+ override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
+ }
+
+ class IntHasher extends Hasher[Int] {
+ override def hash(o: Int): Int = o
+ }
+
+ private def grow1(newSize: Int) {}
+ private def move1(oldPos: Int, newPos: Int) { }
+
+ private val grow = grow1 _
+ private val move = move1 _
+}
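A minimal sketch for the new OpenHashSet, covering both the high-level add/contains API and the low-level addWithoutResize/rehashIfNeeded contract described above. The class is private[spark], so the sketch assumes the org.apache.spark package; the object name is hypothetical:

package org.apache.spark

import org.apache.spark.util.collection.OpenHashSet

object OpenHashSetDemo {
  def main(args: Array[String]) {
    val set = new OpenHashSet[Long](16)           // specialized for Long, no boxing
    (1L to 5L).foreach(set.add(_))
    assert(set.contains(3L) && !set.contains(42L))
    assert(set.size == 5)

    // Low-level API: addWithoutResize returns the slot, with the high bit set when the
    // key was absent before; the caller must then invoke rehashIfNeeded.
    val ret = set.addWithoutResize(99L)
    assert((ret & OpenHashSet.NONEXISTENCE_MASK) != 0)
    set.rehashIfNeeded(99L, _ => (), (_, _) => ())
    assert(set.contains(99L))
  }
}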
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..2e1ef06cbc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect._
+
+/**
+ * A fast hash map implementation for primitive, non-null keys. This hash map supports
+ * insertions and updates, but not deletions. This map is about an order of magnitude
+ * faster than java.util.HashMap, with much less space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+ @specialized(Long, Int, Double) V: ClassTag](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _keySet: OpenHashSet[K] = _
+ private var _values: Array[V] = _
+ _keySet = new OpenHashSet[K](initialCapacity)
+ _values = new Array[V](_keySet.capacity)
+
+ private var _oldValues: Array[V] = null
+
+ override def size = _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ val pos = _keySet.getPos(k)
+ _values(pos)
+ }
+
+  /** Get the value for a given key, or return elseValue if the key doesn't exist. */
+ def getOrElse(k: K, elseValue: V): V = {
+ val pos = _keySet.getPos(k)
+ if (pos >= 0) _values(pos) else elseValue
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = 0
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the unspecialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
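A minimal sketch for the new PrimitiveKeyOpenHashMap with Long keys and Int values, again assuming the org.apache.spark package (the class is private[spark]) and a hypothetical demo object name:

package org.apache.spark

import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap

object PrimitiveKeyMapDemo {
  def main(args: Array[String]) {
    // Keys must be Long or Int (enforced by the require on ClassTag); values can be
    // Long, Int or Double without boxing thanks to @specialized.
    val hits = new PrimitiveKeyOpenHashMap[Long, Int](16)
    Seq(1L, 2L, 1L, 1L).foreach(id => hits.changeValue(id, 1, _ + 1))

    assert(hits(1L) == 3)
    assert(hits.getOrElse(7L, 0) == 0)
    assert(hits.size == 2)
  }
}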
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000000..465c221d5f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) {
+ private var numElements = 0
+ private var array: Array[V] = _
+
+ // NB: This must be separate from the declaration, otherwise the specialized parent class
+ // will get its own array with the same initial size. TODO: Figure out why...
+ array = new Array[V](initialSize)
+
+ def apply(index: Int): V = {
+ require(index < numElements)
+ array(index)
+ }
+
+ def +=(value: V) {
+ if (numElements == array.length) { resize(array.length * 2) }
+ array(numElements) = value
+ numElements += 1
+ }
+
+ def length = numElements
+
+ def getUnderlyingArray = array
+
+ /** Resizes the array, dropping elements if the total length decreases. */
+ def resize(newLength: Int) {
+ val newArray = new Array[V](newLength)
+ array.copyToArray(newArray)
+ array = newArray
+ }
+}
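A minimal sketch for the new PrimitiveVector, showing append, indexed access and the growth behaviour; it assumes the org.apache.spark package (the class is private[spark]) and a hypothetical object name:

package org.apache.spark

import org.apache.spark.util.collection.PrimitiveVector

object PrimitiveVectorDemo {
  def main(args: Array[String]) {
    val vec = new PrimitiveVector[Double](initialSize = 2)
    (1 to 5).foreach(i => vec += i.toDouble)   // doubles the array twice: 2 -> 4 -> 8
    assert(vec.length == 5)
    assert(vec(4) == 5.0)
    // The backing array may be longer than the logical length.
    assert(vec.getUnderlyingArray.length >= vec.length)
  }
}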