aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/pom.xml16
-rw-r--r--core/src/main/scala/spark/CacheManager.scala4
-rw-r--r--core/src/main/scala/spark/DoubleRDDFunctions.scala4
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala2
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala66
-rw-r--r--core/src/main/scala/spark/Partition.scala (renamed from core/src/main/scala/spark/Split.scala)2
-rw-r--r--core/src/main/scala/spark/RDD.scala138
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala12
-rw-r--r--core/src/main/scala/spark/SparkContext.scala99
-rw-r--r--core/src/main/scala/spark/Utils.scala8
-rw-r--r--core/src/main/scala/spark/api/java/JavaDoubleRDD.scala7
-rw-r--r--core/src/main/scala/spark/api/java/JavaPairRDD.scala47
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDD.scala7
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala14
-rw-r--r--core/src/main/scala/spark/api/java/JavaSparkContext.scala22
-rw-r--r--core/src/main/scala/spark/api/python/PythonPartitioner.scala2
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala60
-rw-r--r--core/src/main/scala/spark/deploy/ApplicationDescription.scala (renamed from core/src/main/scala/spark/deploy/JobDescription.scala)4
-rw-r--r--core/src/main/scala/spark/deploy/DeployMessage.scala29
-rw-r--r--core/src/main/scala/spark/deploy/JsonProtocol.scala18
-rw-r--r--core/src/main/scala/spark/deploy/LocalSparkCluster.scala38
-rw-r--r--core/src/main/scala/spark/deploy/client/Client.scala37
-rw-r--r--core/src/main/scala/spark/deploy/client/ClientListener.scala2
-rw-r--r--core/src/main/scala/spark/deploy/client/TestClient.scala6
-rw-r--r--core/src/main/scala/spark/deploy/master/ApplicationInfo.scala (renamed from core/src/main/scala/spark/deploy/master/JobInfo.scala)10
-rw-r--r--core/src/main/scala/spark/deploy/master/ApplicationState.scala11
-rw-r--r--core/src/main/scala/spark/deploy/master/ExecutorInfo.scala4
-rw-r--r--core/src/main/scala/spark/deploy/master/JobState.scala9
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala221
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterWebUI.scala34
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerInfo.scala6
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala29
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala92
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerArguments.scala2
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala4
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala30
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala35
-rw-r--r--core/src/main/scala/spark/partial/ApproximateActionListener.scala6
-rw-r--r--core/src/main/scala/spark/rdd/BlockRDD.scala24
-rw-r--r--core/src/main/scala/spark/rdd/CartesianRDD.scala37
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala22
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala59
-rw-r--r--core/src/main/scala/spark/rdd/CoalescedRDD.scala33
-rw-r--r--core/src/main/scala/spark/rdd/FilteredRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/FlatMappedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/GlommedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/HadoopRDD.scala25
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsRDD.scala8
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala (renamed from core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala)12
-rw-r--r--core/src/main/scala/spark/rdd/MappedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/NewHadoopRDD.scala24
-rw-r--r--core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala (renamed from core/src/main/scala/spark/ParallelCollection.scala)31
-rw-r--r--core/src/main/scala/spark/rdd/PartitionPruningRDD.scala16
-rw-r--r--core/src/main/scala/spark/rdd/PipedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/SampledRDD.scala22
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala10
-rw-r--r--core/src/main/scala/spark/rdd/UnionRDD.scala34
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala36
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala295
-rw-r--r--core/src/main/scala/spark/scheduler/JobResult.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/JobWaiter.scala14
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala13
-rw-r--r--core/src/main/scala/spark/scheduler/Stage.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala12
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala27
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala40
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala19
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala20
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala18
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala5
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala6
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerUI.scala39
-rw-r--r--core/src/main/scala/spark/storage/StorageUtils.scala24
-rw-r--r--core/src/main/scala/spark/util/AkkaUtils.scala7
-rw-r--r--core/src/main/scala/spark/util/MetadataCleaner.scala10
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_details.scala.html40
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_row.scala.html20
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_table.scala.html (renamed from core/src/main/twirl/spark/deploy/master/job_table.scala.html)8
-rw-r--r--core/src/main/twirl/spark/deploy/master/executor_row.scala.html6
-rw-r--r--core/src/main/twirl/spark/deploy/master/index.scala.html22
-rw-r--r--core/src/main/twirl/spark/deploy/master/job_details.scala.html40
-rw-r--r--core/src/main/twirl/spark/deploy/master/job_row.scala.html20
-rw-r--r--core/src/main/twirl/spark/deploy/worker/executor_row.scala.html10
-rw-r--r--core/src/main/twirl/spark/deploy/worker/index.scala.html6
-rw-r--r--core/src/main/twirl/spark/storage/rdd.scala.html6
-rw-r--r--core/src/main/twirl/spark/storage/rdd_table.scala.html6
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala116
-rw-r--r--core/src/test/scala/spark/DriverSuite.scala3
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java24
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala3
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala34
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala25
-rw-r--r--core/src/test/scala/spark/SortingSuite.scala10
-rw-r--r--core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala (renamed from core/src/test/scala/spark/ParallelCollectionSplitSuite.scala)40
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala663
-rw-r--r--core/src/test/scala/spark/scheduler/TaskContextSuite.scala10
101 files changed, 2070 insertions, 1174 deletions
diff --git a/core/pom.xml b/core/pom.xml
index 862d3ec37a..66c62151fe 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
<scope>test</scope>
@@ -163,11 +168,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -220,12 +220,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala
index 711435c333..c7b379a3fb 100644
--- a/core/src/main/scala/spark/CacheManager.scala
+++ b/core/src/main/scala/spark/CacheManager.scala
@@ -11,13 +11,13 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
- def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
- // Split is in cache, so just return its values
+ // Partition is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala
index b2a0e2b631..178d31a73b 100644
--- a/core/src/main/scala/spark/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala
@@ -42,14 +42,14 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
/** (Experimental) Approximate operation to return the mean within a timeout. */
def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
- val evaluator = new MeanEvaluator(self.splits.size, confidence)
+ val evaluator = new MeanEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
- val evaluator = new SumEvaluator(self.splits.size, confidence)
+ val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
}
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index aaf433b324..4735207585 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}
}
- def cleanup(cleanupTime: Long) {
+ private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 231e23a7de..4319cbd892 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -62,7 +62,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
- if (mapSideCombine) {
+ if (self.partitioner == Some(partitioner)) {
+ self.mapPartitions(aggregator.combineValuesByKey(_), true)
+ } else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
@@ -81,8 +83,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
- numSplits: Int): RDD[(K, C)] = {
- combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
+ numPartitions: Int): RDD[(K, C)] = {
+ combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
}
/**
@@ -143,10 +145,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
+ * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
*/
- def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
- reduceByKey(new HashPartitioner(numSplits), func)
+ def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = {
+ reduceByKey(new HashPartitioner(numPartitions), func)
}
/**
@@ -164,10 +166,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with into `numSplits` partitions.
+ * resulting RDD with into `numPartitions` partitions.
*/
- def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
- groupByKey(new HashPartitioner(numSplits))
+ def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = {
+ groupByKey(new HashPartitioner(numPartitions))
}
/**
@@ -285,8 +287,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
- def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = {
- join(other, new HashPartitioner(numSplits))
+ def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = {
+ join(other, new HashPartitioner(numPartitions))
}
/**
@@ -303,10 +305,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
- * into `numSplits` partitions.
+ * into `numPartitions` partitions.
*/
- def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = {
- leftOuterJoin(other, new HashPartitioner(numSplits))
+ def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = {
+ leftOuterJoin(other, new HashPartitioner(numPartitions))
}
/**
@@ -325,8 +327,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD into the given number of partitions.
*/
- def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = {
- rightOuterJoin(other, new HashPartitioner(numSplits))
+ def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = {
+ rightOuterJoin(other, new HashPartitioner(numPartitions))
}
/**
@@ -361,7 +363,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](
- Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
+ Seq(self.asInstanceOf[RDD[(K, _)]], other.asInstanceOf[RDD[(K, _)]]),
partitioner)
val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
prfs.mapValues {
@@ -380,9 +382,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](
- Seq(self.asInstanceOf[RDD[(_, _)]],
- other1.asInstanceOf[RDD[(_, _)]],
- other2.asInstanceOf[RDD[(_, _)]]),
+ Seq(self.asInstanceOf[RDD[(K, _)]],
+ other1.asInstanceOf[RDD[(K, _)]],
+ other2.asInstanceOf[RDD[(K, _)]]),
partitioner)
val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
prfs.mapValues {
@@ -412,17 +414,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Seq[V], Seq[W]))] = {
- cogroup(other, new HashPartitioner(numSplits))
+ def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, new HashPartitioner(numPartitions))
}
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
- def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numSplits: Int)
+ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
- cogroup(other1, other2, new HashPartitioner(numSplits))
+ cogroup(other1, other2, new HashPartitioner(numPartitions))
}
/** Alias for cogroup. */
@@ -465,7 +467,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- self.filter(_._1 == key).map(_._2).collect
+ self.filter(_._1 == key).map(_._2).collect()
}
}
@@ -590,7 +592,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
var count = 0
while(iter.hasNext) {
- val record = iter.next
+ val record = iter.next()
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
@@ -634,9 +636,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
- def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[(K,V)] = {
val shuffled =
- new ShuffledRDD[K, V](self, new RangePartitioner(numSplits, self, ascending))
+ new ShuffledRDD[K, V](self, new RangePartitioner(numPartitions, self, ascending))
shuffled.mapPartitions(iter => {
val buf = iter.toArray
if (ascending) {
@@ -650,9 +652,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
private[spark]
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
- override def getSplits = firstParent[(K, V)].splits
+ override def getPartitions = firstParent[(K, V)].partitions
override val partitioner = firstParent[(K, V)].partitioner
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
}
@@ -660,9 +662,9 @@ private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
extends RDD[(K, U)](prev) {
- override def getSplits = firstParent[(K, V)].splits
+ override def getPartitions = firstParent[(K, V)].partitions
override val partitioner = firstParent[(K, V)].partitioner
- override def compute(split: Split, context: TaskContext) = {
+ override def compute(split: Partition, context: TaskContext) = {
firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}
diff --git a/core/src/main/scala/spark/Split.scala b/core/src/main/scala/spark/Partition.scala
index 90d4b47c55..e384308ef6 100644
--- a/core/src/main/scala/spark/Split.scala
+++ b/core/src/main/scala/spark/Partition.scala
@@ -3,7 +3,7 @@ package spark
/**
* A partition of an RDD.
*/
-trait Split extends Serializable {
+trait Partition extends Serializable {
/**
* Get the split's index within its parent RDD
*/
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 210404d540..da82dfd10f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -20,13 +20,14 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
+import spark.rdd.CoalescedRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
import spark.rdd.GlommedRDD
import spark.rdd.MappedRDD
import spark.rdd.MapPartitionsRDD
-import spark.rdd.MapPartitionsWithSplitRDD
+import spark.rdd.MapPartitionsWithIndexRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.UnionRDD
@@ -48,7 +49,7 @@ import SparkContext._
*
* Internally, each RDD is characterized by five main properties:
*
- * - A list of splits (partitions)
+ * - A list of partitions
* - A function for computing each split
* - A list of dependencies on other RDDs
* - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
@@ -75,13 +76,13 @@ abstract class RDD[T: ClassManifest](
// =======================================================================
/** Implemented by subclasses to compute a given partition. */
- def compute(split: Split, context: TaskContext): Iterator[T]
+ def compute(split: Partition, context: TaskContext): Iterator[T]
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*/
- protected def getSplits: Array[Split]
+ protected def getPartitions: Array[Partition]
/**
* Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
@@ -90,7 +91,7 @@ abstract class RDD[T: ClassManifest](
protected def getDependencies: Seq[Dependency[_]] = deps
/** Optionally overridden by subclasses to specify placement preferences. */
- protected def getPreferredLocations(split: Split): Seq[String] = Nil
+ protected def getPreferredLocations(split: Partition): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
@@ -136,10 +137,10 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- // Our dependencies and splits will be gotten by calling subclass's methods below, and will
+ // Our dependencies and partitions will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
private var dependencies_ : Seq[Dependency[_]] = null
- @transient private var splits_ : Array[Split] = null
+ @transient private var partitions_ : Array[Partition] = null
/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -158,15 +159,15 @@ abstract class RDD[T: ClassManifest](
}
/**
- * Get the array of splits of this RDD, taking into account whether the
+ * Get the array of partitions of this RDD, taking into account whether the
* RDD is checkpointed or not.
*/
- final def splits: Array[Split] = {
- checkpointRDD.map(_.splits).getOrElse {
- if (splits_ == null) {
- splits_ = getSplits
+ final def partitions: Array[Partition] = {
+ checkpointRDD.map(_.partitions).getOrElse {
+ if (partitions_ == null) {
+ partitions_ = getPartitions
}
- splits_
+ partitions_
}
}
@@ -174,7 +175,7 @@ abstract class RDD[T: ClassManifest](
* Get the preferred location of a split, taking into account whether the
* RDD is checkpointed or not.
*/
- final def preferredLocations(split: Split): Seq[String] = {
+ final def preferredLocations(split: Partition): Seq[String] = {
checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
getPreferredLocations(split)
}
@@ -185,7 +186,7 @@ abstract class RDD[T: ClassManifest](
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
- final def iterator(split: Split, context: TaskContext): Iterator[T] = {
+ final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
@@ -196,7 +197,7 @@ abstract class RDD[T: ClassManifest](
/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
- private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
+ private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
firstParent[T].iterator(split, context)
} else {
@@ -226,10 +227,15 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): RDD[T] =
- map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
+ def distinct(numPartitions: Int): RDD[T] =
+ map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1)
- def distinct(): RDD[T] = distinct(splits.size)
+ def distinct(): RDD[T] = distinct(partitions.size)
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): RDD[T] = new CoalescedRDD(this, numPartitions)
/**
* Return a sampled subset of this RDD.
@@ -297,9 +303,9 @@ abstract class RDD[T: ClassManifest](
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = {
+ def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
- this.map(t => (cleanF(t), t)).groupByKey(numSplits)
+ this.map(t => (cleanF(t), t)).groupByKey(numPartitions)
}
/**
@@ -330,14 +336,24 @@ abstract class RDD[T: ClassManifest](
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
- /**
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
+ * of the original partition.
+ */
+ def mapPartitionsWithIndex[U: ClassManifest](
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] =
+ new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+
+ /**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
+ @deprecated("use mapPartitionsWithIndex")
def mapPartitionsWithSplit[U: ClassManifest](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
+ new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
@@ -378,27 +394,29 @@ abstract class RDD[T: ClassManifest](
}
/**
- * Reduces the elements of this RDD using the specified associative binary operator.
+ * Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
- }else {
+ } else {
None
}
}
- val options = sc.runJob(this, reducePartition)
- val results = new ArrayBuffer[T]
- for (opt <- options; elem <- opt) {
- results += elem
- }
- if (results.size == 0) {
- throw new UnsupportedOperationException("empty collection")
- } else {
- return results.reduceLeft(cleanF)
+ var jobResult: Option[T] = None
+ val mergeResult = (index: Int, taskResult: Option[T]) => {
+ if (taskResult != None) {
+ jobResult = jobResult match {
+ case Some(value) => Some(f(value, taskResult.get))
+ case None => taskResult
+ }
+ }
}
+ sc.runJob(this, reducePartition, mergeResult)
+ // Get the final result out of our Option, or throw an exception if the RDD was empty
+ jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
/**
@@ -408,9 +426,13 @@ abstract class RDD[T: ClassManifest](
* modify t2.
*/
def fold(zeroValue: T)(op: (T, T) => T): T = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanOp = sc.clean(op)
- val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
- return results.fold(zeroValue)(cleanOp)
+ val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
+ val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
+ sc.runJob(this, foldPartition, mergeResult)
+ jobResult
}
/**
@@ -422,11 +444,14 @@ abstract class RDD[T: ClassManifest](
* allocation.
*/
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
- val results = sc.runJob(this,
- (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
- return results.fold(zeroValue)(cleanCombOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+ sc.runJob(this, aggregatePartition, mergeResult)
+ jobResult
}
/**
@@ -437,7 +462,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}).sum
@@ -452,11 +477,11 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}
- val evaluator = new CountEvaluator(splits.size, confidence)
+ val evaluator = new CountEvaluator(partitions.size, confidence)
sc.runApproximateJob(this, countElements, evaluator, timeout)
}
@@ -507,7 +532,7 @@ abstract class RDD[T: ClassManifest](
}
map
}
- val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
+ val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence)
sc.runApproximateJob(this, countPartition, evaluator, timeout)
}
@@ -522,7 +547,7 @@ abstract class RDD[T: ClassManifest](
}
val buf = new ArrayBuffer[T]
var p = 0
- while (buf.size < num && p < splits.size) {
+ while (buf.size < num && p < partitions.size) {
val left = num - buf.size
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
buf ++= res(0)
@@ -621,27 +646,32 @@ abstract class RDD[T: ClassManifest](
/** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
+ // Avoid handling doCheckpoint multiple times to prevent excessive recursion
+ private var doCheckpointCalled = false
+
/**
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/
private[spark] def doCheckpoint() {
- if (checkpointData.isDefined) {
- checkpointData.get.doCheckpoint()
- } else {
- dependencies.foreach(_.rdd.doCheckpoint())
+ if (!doCheckpointCalled) {
+ doCheckpointCalled = true
+ if (checkpointData.isDefined) {
+ checkpointData.get.doCheckpoint()
+ } else {
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
}
}
/**
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
- * created from the checkpoint file, and forget its old dependencies and splits.
+ * created from the checkpoint file, and forget its old dependencies and partitions.
*/
private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
clearDependencies()
- dependencies_ = null
- splits_ = null
+ partitions_ = null
deps = null // Forget the constructor argument for dependencies too
}
@@ -656,15 +686,15 @@ abstract class RDD[T: ClassManifest](
}
/** A description of this RDD and its recursive dependencies for debugging. */
- def toDebugString(): String = {
+ def toDebugString: String = {
def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
- Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++
+ Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++
rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " "))
}
debugString(this).mkString("\n")
}
- override def toString(): String = "%s%s[%d] at %s".format(
+ override def toString: String = "%s%s[%d] at %s".format(
Option(name).map(_ + " ").getOrElse(""),
getClass.getSimpleName,
id,
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index a4a4ebaf53..d00092e984 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -16,7 +16,7 @@ private[spark] object CheckpointState extends Enumeration {
/**
* This class contains all the information related to RDD checkpointing. Each instance of this class
* is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as,
- * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations
+ * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations
* of the checkpointed RDD.
*/
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
@@ -67,11 +67,11 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
val newRDD = new CheckpointRDD[T](rdd.context, path)
- // Change the dependencies and splits of the RDD
+ // Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpRDD = Some(newRDD)
- rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
+ rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
@@ -79,15 +79,15 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
}
// Get preferred location of a split after checkpointing
- def getPreferredLocations(split: Split): Seq[String] = {
+ def getPreferredLocations(split: Partition): Seq[String] = {
RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split)
}
}
- def getSplits: Array[Split] = {
+ def getPartitions: Array[Partition] = {
RDDCheckpointData.synchronized {
- cpRDD.get.splits
+ cpRDD.get.partitions
}
}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index b0d4b58240..d39767c3b3 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -39,20 +39,21 @@ import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD}
+import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import storage.BlockManagerUI
import util.{MetadataCleaner, TimeStampedHashMap}
+import storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI.
+ * @param appName A name for your application, to display on the cluster web UI.
* @param sparkHome Location where Spark is installed on cluster nodes.
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
@@ -60,7 +61,7 @@ import util.{MetadataCleaner, TimeStampedHashMap}
*/
class SparkContext(
val master: String,
- val jobName: String,
+ val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
environment: Map[String, String] = Map())
@@ -107,8 +108,9 @@ class SparkContext(
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
+ // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
- "SPARK_TESTING")) {
+ "SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
@@ -141,7 +143,7 @@ class SparkContext(
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
scheduler.initialize(backend)
scheduler
@@ -160,7 +162,7 @@ class SparkContext(
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
@@ -176,9 +178,9 @@ class SparkContext(
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
+ new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
} else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
+ new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
}
scheduler.initialize(backend)
scheduler
@@ -187,6 +189,7 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
@@ -213,7 +216,7 @@ class SparkContext(
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
- new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]())
+ new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
@@ -226,7 +229,7 @@ class SparkContext(
* Create a new partition for each collection item. */
def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
- new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs)
+ new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
/**
@@ -467,13 +470,28 @@ class SparkContext(
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
*/
- def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
+ def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
}
}
/**
+ * Return information about what RDDs are cached, if they are in mem or on disk, how much space
+ * they take, etc.
+ */
+ def getRDDStorageInfo : Array[RDDInfo] = {
+ StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
+ }
+
+ /**
+ * Return information about blocks stored in all of the slaves
+ */
+ def getExecutorStorageStatus : Array[StorageStatus] = {
+ env.blockManager.master.getStorageStatus
+ }
+
+ /**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
@@ -543,27 +561,43 @@ class SparkContext(
}
/**
- * Run a function on a given set of partitions in an RDD and return the results. This is the main
- * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
- * whether the scheduler can run the computation on the driver rather than shipping it out to the
- * cluster, for short actions like first().
+ * Run a function on a given set of partitions in an RDD and pass the results to the given
+ * handler function. This is the main entry point for all actions in Spark. The allowLocal
+ * flag specifies whether the scheduler can run the computation on the driver rather than
+ * shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
- allowLocal: Boolean
- ): Array[U] = {
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
/**
+ * Run a function on a given set of partitions in an RDD and return the results as an array. The
+ * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
+ * than shipping it out to the cluster, for short actions like first().
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean
+ ): Array[U] = {
+ val results = new Array[U](partitions.size)
+ runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
+ results
+ }
+
+ /**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
@@ -580,14 +614,37 @@ class SparkContext(
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.splits.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.splits.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.size, false)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: (TaskContext, Iterator[T]) => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
+ runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
}
/**
@@ -639,7 +696,7 @@ class SparkContext(
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
- /** Default min number of splits for Hadoop RDDs when not given by user */
+ /** Default min number of partitions for Hadoop RDDs when not given by user */
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
private var nextShuffleId = new AtomicInteger(0)
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 1e58d01273..28d643abca 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -12,6 +12,7 @@ import scala.io.Source
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some
+import spark.serializer.SerializerInstance
/**
* Various utility methods used by Spark.
@@ -446,4 +447,11 @@ private object Utils extends Logging {
socket.close()
portBound
}
+
+ /**
+ * Clone an object using a Spark serializer.
+ */
+ def clone[T](value: T, serializer: SerializerInstance): T = {
+ serializer.deserialize[T](serializer.serialize(value))
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
index 843e1bd18b..da3cb2cd31 100644
--- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
@@ -44,7 +44,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits))
+ def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
@@ -53,6 +53,11 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
fromRDD(srdd.filter(x => f(x).booleanValue()))
/**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD =
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 8ce32e0e2f..df3af3817d 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -54,7 +54,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits))
+ def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
@@ -63,6 +63,11 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
/**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.coalesce(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
@@ -117,8 +122,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
- numSplits: Int): JavaPairRDD[K, C] =
- combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
+ numPartitions: Int): JavaPairRDD[K, C] =
+ combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
/**
* Merge the values for each key using an associative reduce function. This will also perform
@@ -157,10 +162,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
+ * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
*/
- def reduceByKey(func: JFunction2[V, V, V], numSplits: Int): JavaPairRDD[K, V] =
- fromRDD(rdd.reduceByKey(func, numSplits))
+ def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] =
+ fromRDD(rdd.reduceByKey(func, numPartitions))
/**
* Group the values for each key in the RDD into a single sequence. Allows controlling the
@@ -171,10 +176,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with into `numSplits` partitions.
+ * resulting RDD with into `numPartitions` partitions.
*/
- def groupByKey(numSplits: Int): JavaPairRDD[K, JList[V]] =
- fromRDD(groupByResultToJava(rdd.groupByKey(numSplits)))
+ def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions)))
/**
* Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine`
@@ -256,8 +261,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
- def join[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, W)] =
- fromRDD(rdd.join(other, numSplits))
+ def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other, numPartitions))
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
@@ -272,10 +277,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
- * into `numSplits` partitions.
+ * into `numPartitions` partitions.
*/
- def leftOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, Option[W])] =
- fromRDD(rdd.leftOuterJoin(other, numSplits))
+ def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Option[W])] =
+ fromRDD(rdd.leftOuterJoin(other, numPartitions))
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
@@ -292,8 +297,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD into the given number of partitions.
*/
- def rightOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (Option[V], W)] =
- fromRDD(rdd.rightOuterJoin(other, numSplits))
+ def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Option[V], W)] =
+ fromRDD(rdd.rightOuterJoin(other, numPartitions))
/**
* Return the key-value pairs in this RDD to the master as a Map.
@@ -357,16 +362,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (JList[V], JList[W])]
- = fromRDD(cogroupResultToJava(rdd.cogroup(other, numSplits)))
+ def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])]
+ = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
- def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numSplits: Int)
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
- fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numSplits)))
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
/** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
@@ -447,7 +452,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
*/
def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
- sortByKey(comp, true)
+ sortByKey(comp, ascending)
}
/**
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index ac31350ec3..3ccd6f055e 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -30,7 +30,7 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits))
+ def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
@@ -39,6 +39,11 @@ JavaRDDLike[T, JavaRDD[T]] {
wrapRDD(rdd.filter((x => f(x).booleanValue())))
/**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions)
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 60025b459c..90b45cf875 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -4,7 +4,7 @@ import java.util.{List => JList}
import scala.Tuple2
import scala.collection.JavaConversions._
-import spark.{SparkContext, Split, RDD, TaskContext}
+import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import spark.partial.{PartialResult, BoundedDouble}
@@ -20,7 +20,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
def rdd: RDD[T]
/** Set of partitions in this RDD. */
- def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
+ def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
/** The [[spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
@@ -36,7 +36,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
- def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] =
+ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] =
asJavaIterator(rdd.iterator(split, taskContext))
// Transformations (return a new RDD)
@@ -146,12 +146,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K], numSplits: Int): JavaPairRDD[K, JList[T]] = {
+ def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
implicit val vcm: ClassManifest[JList[T]] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numSplits)(f.returnType)))(kcm, vcm)
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
}
/**
@@ -201,7 +201,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
}
/**
- * Reduces the elements of this RDD using the specified associative binary operator.
+ * Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
@@ -333,6 +333,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
/** A description of this RDD and its recursive dependencies for debugging. */
def toDebugString(): String = {
- rdd.toDebugString()
+ rdd.toDebugString
}
}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index 50b8970cd8..f75fc27c7b 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -23,41 +23,41 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
*/
- def this(master: String, jobName: String) = this(new SparkContext(master, jobName))
+ def this(master: String, appName: String) = this(new SparkContext(master, appName))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
- def this(master: String, jobName: String, sparkHome: String, jarFile: String) =
- this(new SparkContext(master, jobName, sparkHome, Seq(jarFile)))
+ def this(master: String, appName: String, sparkHome: String, jarFile: String) =
+ this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
- def this(master: String, jobName: String, sparkHome: String, jars: Array[String]) =
- this(new SparkContext(master, jobName, sparkHome, jars.toSeq))
+ def this(master: String, appName: String, sparkHome: String, jars: Array[String]) =
+ this(new SparkContext(master, appName, sparkHome, jars.toSeq))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
* @param environment Environment variables to set on worker nodes
*/
- def this(master: String, jobName: String, sparkHome: String, jars: Array[String],
+ def this(master: String, appName: String, sparkHome: String, jars: Array[String],
environment: JMap[String, String]) =
- this(new SparkContext(master, jobName, sparkHome, jars.toSeq, environment))
+ this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment))
private[spark] val env = sc.env
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 519e310323..d618c098c2 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -9,7 +9,7 @@ import java.util.Arrays
*
* Stores the unique id() of the Python-side partitioning function so that it is incorporated into
* equality comparisons. Correctness requires that the id is a unique identifier for the
- * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * lifetime of the program (i.e. that it is not re-used as the id of a different partitioning
* function). This can be ensured by using the Python id() function and maintaining a reference
* to the Python partitioning function so that its id() is not reused.
*/
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f43a152ca7..8c73477384 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -32,11 +32,11 @@ private[spark] class PythonRDD[T: ClassManifest](
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
broadcastVars, accumulator)
- override def getSplits = parent.splits
+ override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
- override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
@@ -65,7 +65,7 @@ private[spark] class PythonRDD[T: ClassManifest](
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
val dOut = new DataOutputStream(proc.getOutputStream)
- // Split index
+ // Partition index
dOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
@@ -103,21 +103,27 @@ private[spark] class PythonRDD[T: ClassManifest](
private def read(): Array[Byte] = {
try {
- val length = stream.readInt()
- if (length != -1) {
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
- } else {
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
- val update = new Array[Byte](len2)
- stream.readFully(update)
- accumulator += Collections.singletonList(update)
- }
- new Array[Byte](0)
+ stream.readInt() match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ case -2 =>
+ // Signals that an exception has been thrown in python
+ val exLength = stream.readInt()
+ val obj = new Array[Byte](exLength)
+ stream.readFully(obj)
+ throw new PythonException(new String(obj))
+ case -1 =>
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
}
} catch {
case eof: EOFException => {
@@ -140,14 +146,17 @@ private[spark] class PythonRDD[T: ClassManifest](
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Array[Byte], Array[Byte])](prev) {
- override def getSplits = prev.splits
- override def compute(split: Split, context: TaskContext) =
+ override def getPartitions = prev.partitions
+ override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
@@ -229,6 +238,11 @@ private[spark] object PythonRDD {
}
def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ import scala.collection.JavaConverters._
+ writeIteratorToPickleFile(items.asScala, filename)
+ }
+
+ def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
writeAsPickle(item, file)
@@ -236,8 +250,10 @@ private[spark] object PythonRDD {
file.close()
}
- def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
- rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
+ def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
+ implicit val cm : ClassManifest[T] = rdd.elementClassManifest
+ rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
+ }
}
private object Pickle {
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
index 7160fc05fc..6659e53b25 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
@@ -1,6 +1,6 @@
package spark.deploy
-private[spark] class JobDescription(
+private[spark] class ApplicationDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
@@ -10,5 +10,5 @@ private[spark] class JobDescription(
val user = System.getProperty("user.name", "<unknown>")
- override def toString: String = "JobDescription(" + name + ")"
+ override def toString: String = "ApplicationDescription(" + name + ")"
}
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 35f40c6e91..3cbf4fdd98 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -1,7 +1,7 @@
package spark.deploy
import spark.deploy.ExecutorState.ExecutorState
-import spark.deploy.master.{WorkerInfo, JobInfo}
+import spark.deploy.master.{WorkerInfo, ApplicationInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
@@ -23,37 +23,39 @@ case class RegisterWorker(
private[spark]
case class ExecutorStateChanged(
- jobId: String,
+ appId: String,
execId: Int,
state: ExecutorState,
message: Option[String],
exitStatus: Option[Int])
extends DeployMessage
+private[spark] case class Heartbeat(workerId: String) extends DeployMessage
+
// Master to Worker
private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage
-private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
+private[spark] case class KillExecutor(appId: String, execId: Int) extends DeployMessage
private[spark] case class LaunchExecutor(
- jobId: String,
+ appId: String,
execId: Int,
- jobDesc: JobDescription,
+ appDesc: ApplicationDescription,
cores: Int,
memory: Int,
sparkHome: String)
extends DeployMessage
-
// Client to Master
-private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
+private[spark] case class RegisterApplication(appDescription: ApplicationDescription)
+ extends DeployMessage
// Master to Client
private[spark]
-case class RegisteredJob(jobId: String) extends DeployMessage
+case class RegisteredApplication(appId: String) extends DeployMessage
private[spark]
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
@@ -63,7 +65,7 @@ case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String
exitStatus: Option[Int])
private[spark]
-case class JobKilled(message: String)
+case class appKilled(message: String)
// Internal message in Client
@@ -76,8 +78,11 @@ private[spark] case object RequestMasterState
// Master to MasterWebUI
private[spark]
-case class MasterState(uri: String, workers: Array[WorkerInfo], activeJobs: Array[JobInfo],
- completedJobs: Array[JobInfo])
+case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
+ activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+
+ def uri = "spark://" + host + ":" + port
+}
// WorkerWebUI to Worker
private[spark] case object RequestWorkerState
@@ -85,6 +90,6 @@ private[spark] case object RequestWorkerState
// Worker to WorkerWebUI
private[spark]
-case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner],
+case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala
index 732fa08064..38a6ebfc24 100644
--- a/core/src/main/scala/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -1,6 +1,6 @@
package spark.deploy
-import master.{JobInfo, WorkerInfo}
+import master.{ApplicationInfo, WorkerInfo}
import worker.ExecutorRunner
import cc.spray.json._
@@ -20,8 +20,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
)
}
- implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] {
- def write(obj: JobInfo) = JsObject(
+ implicit object AppInfoJsonFormat extends RootJsonWriter[ApplicationInfo] {
+ def write(obj: ApplicationInfo) = JsObject(
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
@@ -31,8 +31,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"submitdate" -> JsString(obj.submitDate.toString))
}
- implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] {
- def write(obj: JobDescription) = JsObject(
+ implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] {
+ def write(obj: ApplicationDescription) = JsObject(
"name" -> JsString(obj.name),
"cores" -> JsNumber(obj.cores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
@@ -44,8 +44,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
def write(obj: ExecutorRunner) = JsObject(
"id" -> JsNumber(obj.execId),
"memory" -> JsNumber(obj.memory),
- "jobid" -> JsString(obj.jobId),
- "jobdesc" -> obj.jobDesc.toJson.asJsObject
+ "appid" -> JsString(obj.appId),
+ "appdesc" -> obj.appDesc.toJson.asJsObject
)
}
@@ -57,8 +57,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum),
"memory" -> JsNumber(obj.workers.map(_.memory).sum),
"memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum),
- "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)),
- "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson))
+ "activeapps" -> JsArray(obj.activeApps.toList.map(_.toJson)),
+ "completedapps" -> JsArray(obj.completedApps.toList.map(_.toJson))
)
}
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 2836574ecb..22319a96ca 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -18,35 +18,23 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
- val localIpAddress = Utils.localIpAddress
+ private val localIpAddress = Utils.localIpAddress
+ private val masterActorSystems = ArrayBuffer[ActorSystem]()
+ private val workerActorSystems = ArrayBuffer[ActorSystem]()
- var masterActor : ActorRef = _
- var masterActorSystem : ActorSystem = _
- var masterPort : Int = _
- var masterUrl : String = _
-
- val workerActorSystems = ArrayBuffer[ActorSystem]()
- val workerActors = ArrayBuffer[ActorRef]()
-
- def start() : String = {
+ def start(): String = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
- masterActorSystem = actorSystem
- masterUrl = "spark://" + localIpAddress + ":" + masterPort
- masterActor = masterActorSystem.actorOf(
- Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+ masterActorSystems += masterSystem
+ val masterUrl = "spark://" + localIpAddress + ":" + masterPort
- /* Start the Slaves */
+ /* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0)
- workerActorSystems += actorSystem
- val actor = actorSystem.actorOf(
- Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)),
- name = "Worker")
- workerActors += actor
+ val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+ memoryPerWorker, masterUrl, null, Some(workerNum))
+ workerActorSystems += workerSystem
}
return masterUrl
@@ -57,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
// Stop the workers before the master so they don't get upset that it disconnected
workerActorSystems.foreach(_.shutdown())
workerActorSystems.foreach(_.awaitTermination())
- masterActorSystem.shutdown()
- masterActorSystem.awaitTermination()
+ masterActorSystems.foreach(_.shutdown())
+ masterActorSystems.foreach(_.awaitTermination())
}
}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index 90fe9508cd..1a95524cf9 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -8,30 +8,25 @@ import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
-import spark.deploy.RegisterJob
+import spark.deploy.RegisterApplication
+import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated
import akka.dispatch.Await
/**
- * The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description,
- * and a listener for job events, and calls back the listener when various events occur.
+ * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
+ * and a listener for cluster events, and calls back the listener when various events occur.
*/
private[spark] class Client(
actorSystem: ActorSystem,
masterUrl: String,
- jobDescription: JobDescription,
+ appDescription: ApplicationDescription,
listener: ClientListener)
extends Logging {
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
-
var actor: ActorRef = null
- var jobId: String = null
-
- if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
- throw new SparkException("Invalid master URL: " + masterUrl)
- }
+ var appId: String = null
class ClientActor extends Actor with Logging {
var master: ActorRef = null
@@ -39,13 +34,11 @@ private[spark] class Client(
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() {
- val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
+ logInfo("Connecting to master " + masterUrl)
try {
- master = context.actorFor(akkaUrl)
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address
- master ! RegisterJob(jobDescription)
+ master ! RegisterApplication(appDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
} catch {
@@ -57,17 +50,17 @@ private[spark] class Client(
}
override def receive = {
- case RegisteredJob(jobId_) =>
- jobId = jobId_
- listener.connected(jobId)
+ case RegisteredApplication(appId_) =>
+ appId = appId_
+ listener.connected(appId)
case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
- val fullId = jobId + "/" + id
+ val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
listener.executorAdded(fullId, workerId, host, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
- val fullId = jobId + "/" + id
+ val fullId = appId + "/" + id
val messageText = message.map(s => " (" + s + ")").getOrElse("")
logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText))
if (ExecutorState.isFinished(state)) {
@@ -114,7 +107,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
- val timeout = 1.seconds
+ val timeout = 5.seconds
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index 7035f4b394..b7008321df 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -8,7 +8,7 @@ package spark.deploy.client
* Users of this API should *not* block inside the callback methods.
*/
private[spark] trait ClientListener {
- def connected(jobId: String): Unit
+ def connected(appId: String): Unit
def disconnected(): Unit
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index 8764c400e2..dc004b59ca 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -2,13 +2,13 @@ package spark.deploy.client
import spark.util.AkkaUtils
import spark.{Logging, Utils}
-import spark.deploy.{Command, JobDescription}
+import spark.deploy.{Command, ApplicationDescription}
private[spark] object TestClient {
class TestListener extends ClientListener with Logging {
def connected(id: String) {
- logInfo("Connected to master, got job ID " + id)
+ logInfo("Connected to master, got app ID " + id)
}
def disconnected() {
@@ -24,7 +24,7 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
- val desc = new JobDescription(
+ val desc = new ApplicationDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
index a274b21c34..3591a94072 100644
--- a/core/src/main/scala/spark/deploy/master/JobInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
@@ -1,18 +1,18 @@
package spark.deploy.master
-import spark.deploy.JobDescription
+import spark.deploy.ApplicationDescription
import java.util.Date
import akka.actor.ActorRef
import scala.collection.mutable
-private[spark] class JobInfo(
+private[spark] class ApplicationInfo(
val startTime: Long,
val id: String,
- val desc: JobDescription,
+ val desc: ApplicationDescription,
val submitDate: Date,
val driver: ActorRef)
{
- var state = JobState.WAITING
+ var state = ApplicationState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
var coresGranted = 0
var endTime = -1L
@@ -48,7 +48,7 @@ private[spark] class JobInfo(
_retryCount
}
- def markFinished(endState: JobState.Value) {
+ def markFinished(endState: ApplicationState.Value) {
state = endState
endTime = System.currentTimeMillis()
}
diff --git a/core/src/main/scala/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/spark/deploy/master/ApplicationState.scala
new file mode 100644
index 0000000000..15016b388d
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/master/ApplicationState.scala
@@ -0,0 +1,11 @@
+package spark.deploy.master
+
+private[spark] object ApplicationState
+ extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
+
+ type ApplicationState = Value
+
+ val WAITING, RUNNING, FINISHED, FAILED = Value
+
+ val MAX_NUM_RETRY = 10
+}
diff --git a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
index 1db2c32633..48e6055fb5 100644
--- a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
@@ -4,12 +4,12 @@ import spark.deploy.ExecutorState
private[spark] class ExecutorInfo(
val id: Int,
- val job: JobInfo,
+ val application: ApplicationInfo,
val worker: WorkerInfo,
val cores: Int,
val memory: Int) {
var state = ExecutorState.LAUNCHING
- def fullId: String = job.id + "/" + id
+ def fullId: String = application.id + "/" + id
}
diff --git a/core/src/main/scala/spark/deploy/master/JobState.scala b/core/src/main/scala/spark/deploy/master/JobState.scala
deleted file mode 100644
index 2b70cf0191..0000000000
--- a/core/src/main/scala/spark/deploy/master/JobState.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package spark.deploy.master
-
-private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
- type JobState = Value
-
- val WAITING, RUNNING, FINISHED, FAILED = Value
-
- val MAX_NUM_RETRY = 10
-}
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index c618e87cdd..1cd68a2aa6 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -3,6 +3,7 @@ package spark.deploy.master
import akka.actor._
import akka.actor.Terminated
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+import akka.util.duration._
import java.text.SimpleDateFormat
import java.util.Date
@@ -15,21 +16,22 @@ import spark.util.AkkaUtils
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
- val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs
+ val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
- var nextJobNumber = 0
+ var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
val idToWorker = new HashMap[String, WorkerInfo]
val actorToWorker = new HashMap[ActorRef, WorkerInfo]
val addressToWorker = new HashMap[Address, WorkerInfo]
- val jobs = new HashSet[JobInfo]
- val idToJob = new HashMap[String, JobInfo]
- val actorToJob = new HashMap[ActorRef, JobInfo]
- val addressToJob = new HashMap[Address, JobInfo]
+ val apps = new HashSet[ApplicationInfo]
+ val idToApp = new HashMap[String, ApplicationInfo]
+ val actorToApp = new HashMap[ActorRef, ApplicationInfo]
+ val addressToApp = new HashMap[Address, ApplicationInfo]
- val waitingJobs = new ArrayBuffer[JobInfo]
- val completedJobs = new ArrayBuffer[JobInfo]
+ val waitingApps = new ArrayBuffer[ApplicationInfo]
+ val completedApps = new ArrayBuffer[ApplicationInfo]
val masterPublicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
@@ -37,15 +39,16 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
// As a temporary workaround before better ways of configuring memory, we allow users to set
- // a flag that will perform round-robin scheduling across the nodes (spreading out each job
- // among all the nodes) instead of trying to consolidate each job onto a small # of nodes.
- val spreadOutJobs = System.getProperty("spark.deploy.spreadOut", "false").toBoolean
+ // a flag that will perform round-robin scheduling across the nodes (spreading out each app
+ // among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
+ val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "false").toBoolean
override def preStart() {
logInfo("Starting Spark master at spark://" + ip + ":" + port)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
startWebUi()
+ context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis)(timeOutDeadWorkers())
}
def startWebUi() {
@@ -73,92 +76,101 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
}
- case RegisterJob(description) => {
- logInfo("Registering job " + description.name)
- val job = addJob(description, sender)
- logInfo("Registered job " + description.name + " with ID " + job.id)
- waitingJobs += job
+ case RegisterApplication(description) => {
+ logInfo("Registering app " + description.name)
+ val app = addApplication(description, sender)
+ logInfo("Registered app " + description.name + " with ID " + app.id)
+ waitingApps += app
context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredJob(job.id)
+ sender ! RegisteredApplication(app.id)
schedule()
}
- case ExecutorStateChanged(jobId, execId, state, message, exitStatus) => {
- val execOption = idToJob.get(jobId).flatMap(job => job.executors.get(execId))
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
+ val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId))
execOption match {
case Some(exec) => {
exec.state = state
- exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus)
+ exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus)
if (ExecutorState.isFinished(state)) {
- val jobInfo = idToJob(jobId)
- // Remove this executor from the worker and job
+ val appInfo = idToApp(appId)
+ // Remove this executor from the worker and app
logInfo("Removing executor " + exec.fullId + " because it is " + state)
- jobInfo.removeExecutor(exec)
+ appInfo.removeExecutor(exec)
exec.worker.removeExecutor(exec)
// Only retry certain number of times so we don't go into an infinite loop.
- if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) {
+ if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) {
schedule()
} else {
- logError("Job %s with ID %s failed %d times, removing it".format(
- jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
- removeJob(jobInfo)
+ logError("Application %s with ID %s failed %d times, removing it".format(
+ appInfo.desc.name, appInfo.id, appInfo.retryCount))
+ removeApplication(appInfo)
}
}
}
case None =>
- logWarning("Got status update for unknown executor " + jobId + "/" + execId)
+ logWarning("Got status update for unknown executor " + appId + "/" + execId)
+ }
+ }
+
+ case Heartbeat(workerId) => {
+ idToWorker.get(workerId) match {
+ case Some(workerInfo) =>
+ workerInfo.lastHeartbeat = System.currentTimeMillis()
+ case None =>
+ logWarning("Got heartbeat from unregistered worker " + workerId)
}
}
case Terminated(actor) => {
- // The disconnected actor could've been either a worker or a job; remove whichever of
+ // The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
- actorToJob.get(actor).foreach(removeJob)
+ actorToApp.get(actor).foreach(removeApplication)
}
case RemoteClientDisconnected(transport, address) => {
- // The disconnected client could've been either a worker or a job; remove whichever it was
+ // The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
- addressToJob.get(address).foreach(removeJob)
+ addressToApp.get(address).foreach(removeApplication)
}
case RemoteClientShutdown(transport, address) => {
- // The disconnected client could've been either a worker or a job; remove whichever it was
+ // The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
- addressToJob.get(address).foreach(removeJob)
+ addressToApp.get(address).foreach(removeApplication)
}
case RequestMasterState => {
- sender ! MasterState(ip + ":" + port, workers.toArray, jobs.toArray, completedJobs.toArray)
+ sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
}
}
/**
- * Can a job use the given worker? True if the worker has enough memory and we haven't already
- * launched an executor for the job on it (right now the standalone backend doesn't like having
+ * Can an app use the given worker? True if the worker has enough memory and we haven't already
+ * launched an executor for the app on it (right now the standalone backend doesn't like having
* two executors on the same worker).
*/
- def canUse(job: JobInfo, worker: WorkerInfo): Boolean = {
- worker.memoryFree >= job.desc.memoryPerSlave && !worker.hasExecutor(job)
+ def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = {
+ worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app)
}
/**
- * Schedule the currently available resources among waiting jobs. This method will be called
- * every time a new job joins or resource availability changes.
+ * Schedule the currently available resources among waiting apps. This method will be called
+ * every time a new app joins or resource availability changes.
*/
def schedule() {
- // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first job
- // in the queue, then the second job, etc.
- if (spreadOutJobs) {
- // Try to spread out each job among all the nodes, until it has all its cores
- for (job <- waitingJobs if job.coresLeft > 0) {
+ // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
+ // in the queue, then the second app, etc.
+ if (spreadOutApps) {
+ // Try to spread out each app among all the nodes, until it has all its cores
+ for (app <- waitingApps if app.coresLeft > 0) {
val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
- .filter(canUse(job, _)).sortBy(_.coresFree).reverse
+ .filter(canUse(app, _)).sortBy(_.coresFree).reverse
val numUsable = usableWorkers.length
val assigned = new Array[Int](numUsable) // Number of cores to give on each node
- var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum)
+ var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum)
var pos = 0
while (toAssign > 0) {
if (usableWorkers(pos).coresFree - assigned(pos) > 0) {
@@ -170,22 +182,22 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
// Now that we've decided how many cores to give on each node, let's actually give them
for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) {
- val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
- launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
- job.state = JobState.RUNNING
+ val exec = app.addExecutor(usableWorkers(pos), assigned(pos))
+ launchExecutor(usableWorkers(pos), exec, app.desc.sparkHome)
+ app.state = ApplicationState.RUNNING
}
}
}
} else {
- // Pack each job into as few nodes as possible until we've assigned all its cores
+ // Pack each app into as few nodes as possible until we've assigned all its cores
for (worker <- workers if worker.coresFree > 0) {
- for (job <- waitingJobs if job.coresLeft > 0) {
- if (canUse(job, worker)) {
- val coresToUse = math.min(worker.coresFree, job.coresLeft)
+ for (app <- waitingApps if app.coresLeft > 0) {
+ if (canUse(app, worker)) {
+ val coresToUse = math.min(worker.coresFree, app.coresLeft)
if (coresToUse > 0) {
- val exec = job.addExecutor(worker, coresToUse)
- launchExecutor(worker, exec, job.desc.sparkHome)
- job.state = JobState.RUNNING
+ val exec = app.addExecutor(worker, coresToUse)
+ launchExecutor(worker, exec, app.desc.sparkHome)
+ app.state = ApplicationState.RUNNING
}
}
}
@@ -196,8 +208,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
- exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+ worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
+ exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
@@ -219,54 +231,85 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address
for (exec <- worker.executors.values) {
- exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
- exec.job.executors -= exec.id
+ logInfo("Telling app of lost executor: " + exec.id)
+ exec.application.driver ! ExecutorUpdated(exec.id, ExecutorState.LOST, Some("worker lost"), None)
+ exec.application.removeExecutor(exec)
}
}
- def addJob(desc: JobDescription, driver: ActorRef): JobInfo = {
+ def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val job = new JobInfo(now, newJobId(date), desc, date, driver)
- jobs += job
- idToJob(job.id) = job
- actorToJob(driver) = job
- addressToJob(driver.path.address) = job
- return job
+ val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver)
+ apps += app
+ idToApp(app.id) = app
+ actorToApp(driver) = app
+ addressToApp(driver.path.address) = app
+ return app
}
- def removeJob(job: JobInfo) {
- if (jobs.contains(job)) {
- logInfo("Removing job " + job.id)
- jobs -= job
- idToJob -= job.id
- actorToJob -= job.driver
- addressToWorker -= job.driver.path.address
- completedJobs += job // Remember it in our history
- waitingJobs -= job
- for (exec <- job.executors.values) {
+ def removeApplication(app: ApplicationInfo) {
+ if (apps.contains(app)) {
+ logInfo("Removing app " + app.id)
+ apps -= app
+ idToApp -= app.id
+ actorToApp -= app.driver
+ addressToWorker -= app.driver.path.address
+ completedApps += app // Remember it in our history
+ waitingApps -= app
+ for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
- exec.worker.actor ! KillExecutor(exec.job.id, exec.id)
+ exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
}
- job.markFinished(JobState.FINISHED) // TODO: Mark it as FAILED if it failed
+ app.markFinished(ApplicationState.FINISHED) // TODO: Mark it as FAILED if it failed
schedule()
}
}
- /** Generate a new job ID given a job's submission date */
- def newJobId(submitDate: Date): String = {
- val jobId = "job-%s-%04d".format(DATE_FORMAT.format(submitDate), nextJobNumber)
- nextJobNumber += 1
- jobId
+ /** Generate a new app ID given a app's submission date */
+ def newApplicationId(submitDate: Date): String = {
+ val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber)
+ nextAppNumber += 1
+ appId
+ }
+
+ /** Check for, and remove, any timed-out workers */
+ def timeOutDeadWorkers() {
+ // Copy the workers into an array so we don't modify the hashset while iterating through it
+ val expirationTime = System.currentTimeMillis() - WORKER_TIMEOUT
+ val toRemove = workers.filter(_.lastHeartbeat < expirationTime).toArray
+ for (worker <- toRemove) {
+ logWarning("Removing %s because we got no heartbeat in %d seconds".format(
+ worker.id, WORKER_TIMEOUT))
+ removeWorker(worker)
+ }
}
}
private[spark] object Master {
+ private val systemName = "sparkMaster"
+ private val actorName = "Master"
+ private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
+
+ /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+ def toAkkaUrl(sparkUrl: String): String = {
+ sparkUrl match {
+ case sparkUrlRegex(host, port) =>
+ "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ case _ =>
+ throw new SparkException("Invalid master URL: " + sparkUrl)
+ }
+ }
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
+ (actorSystem, boundPort)
+ }
}
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index a01774f511..54faa375fb 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -40,35 +40,27 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
}
}
} ~
- path("job") {
- parameters("jobId", 'format ?) {
- case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
+ path("app") {
+ parameters("appId", 'format ?) {
+ case (appId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
- val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => null
- }
- }
+ val appInfo = for (masterState <- future.mapTo[MasterState]) yield {
+ masterState.activeApps.find(_.id == appId).getOrElse({
+ masterState.completedApps.find(_.id == appId).getOrElse(null)
+ })
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
- ctx.complete(jobInfo.mapTo[JobInfo])
+ ctx.complete(appInfo.mapTo[ApplicationInfo])
}
- case (jobId, _) =>
+ case (appId, _) =>
completeWith {
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
-
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => null
- }
- }
+ val app = masterState.activeApps.find(_.id == appId).getOrElse({
+ masterState.completedApps.find(_.id == appId).getOrElse(null)
+ })
+ spark.deploy.master.html.app_details.render(app)
}
}
}
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index 5a7f5fef8a..23df1bb463 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -18,6 +18,8 @@ private[spark] class WorkerInfo(
var coresUsed = 0
var memoryUsed = 0
+ var lastHeartbeat = System.currentTimeMillis()
+
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
@@ -35,8 +37,8 @@ private[spark] class WorkerInfo(
}
}
- def hasExecutor(job: JobInfo): Boolean = {
- executors.values.exists(_.job == job)
+ def hasExecutor(app: ApplicationInfo): Boolean = {
+ executors.values.exists(_.application == app)
}
def webUiAddress : String = {
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index f5ff267d44..de11771c8e 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -1,7 +1,7 @@
package spark.deploy.worker
import java.io._
-import spark.deploy.{ExecutorState, ExecutorStateChanged, JobDescription}
+import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription}
import akka.actor.ActorRef
import spark.{Utils, Logging}
import java.net.{URI, URL}
@@ -14,9 +14,9 @@ import spark.deploy.ExecutorStateChanged
* Manages the execution of one executor process.
*/
private[spark] class ExecutorRunner(
- val jobId: String,
+ val appId: String,
val execId: Int,
- val jobDesc: JobDescription,
+ val appDesc: ApplicationDescription,
val cores: Int,
val memory: Int,
val worker: ActorRef,
@@ -26,7 +26,7 @@ private[spark] class ExecutorRunner(
val workDir: File)
extends Logging {
- val fullId = jobId + "/" + execId
+ val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
var shutdownHook: Thread = null
@@ -60,7 +60,7 @@ private[spark] class ExecutorRunner(
process.destroy()
process.waitFor()
}
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None, None)
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -74,10 +74,10 @@ private[spark] class ExecutorRunner(
}
def buildCommandSeq(): Seq[String] = {
- val command = jobDesc.command
- val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run";
+ val command = appDesc.command
+ val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
val runScript = new File(sparkHome, script).getCanonicalPath
- Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables)
+ Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/** Spawn a thread that will redirect a given stream to a file */
@@ -96,12 +96,12 @@ private[spark] class ExecutorRunner(
}
/**
- * Download and run the executor described in our JobDescription
+ * Download and run the executor described in our ApplicationDescription
*/
def fetchAndRunExecutor() {
try {
// Create the executor's working directory
- val executorDir = new File(workDir, jobId + "/" + execId)
+ val executorDir = new File(workDir, appId + "/" + execId)
if (!executorDir.mkdirs()) {
throw new IOException("Failed to create directory " + executorDir)
}
@@ -110,11 +110,10 @@ private[spark] class ExecutorRunner(
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
val env = builder.environment()
- for ((key, value) <- jobDesc.command.environment) {
+ for ((key, value) <- appDesc.command.environment) {
env.put(key, value)
}
- env.put("SPARK_CORES", cores.toString)
- env.put("SPARK_MEMORY", memory.toString)
+ env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
@@ -129,7 +128,7 @@ private[spark] class ExecutorRunner(
// times on the same machine.
val exitCode = process.waitFor()
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message),
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
Some(exitCode))
} catch {
case interrupted: InterruptedException =>
@@ -141,7 +140,7 @@ private[spark] class ExecutorRunner(
process.destroy()
}
val message = e.getClass + ": " + e.getMessage
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message), None)
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
}
}
}
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 8b41620d98..2bbc931316 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -1,19 +1,18 @@
package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import akka.actor.{ActorRef, Props, Actor}
+import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.util.duration._
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import spark.deploy._
-import akka.remote.RemoteClientLifeCycleEvent
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.text.SimpleDateFormat
import java.util.Date
-import akka.remote.RemoteClientShutdown
-import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker
import spark.deploy.LaunchExecutor
import spark.deploy.RegisterWorkerFailed
-import akka.actor.Terminated
+import spark.deploy.master.Master
import java.io.File
private[spark] class Worker(
@@ -27,7 +26,9 @@ private[spark] class Worker(
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
+
+ // Send a heartbeat every (heartbeat timeout) / 4 milliseconds
+ val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
var master: ActorRef = null
var masterWebUiUrl : String = ""
@@ -48,11 +49,7 @@ private[spark] class Worker(
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
- workDir = if (workDirPath != null) {
- new File(workDirPath)
- } else {
- new File(sparkHome, "work")
- }
+ workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@@ -68,8 +65,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
- val envVar = System.getenv("SPARK_HOME")
- sparkHome = new File(if (envVar == null) "." else envVar)
+ sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
connectToMaster()
@@ -77,24 +73,15 @@ private[spark] class Worker(
}
def connectToMaster() {
- masterUrl match {
- case MASTER_REGEX(masterHost, masterPort) => {
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
- try {
- master = context.actorFor(akkaUrl)
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
- } catch {
- case e: Exception =>
- logError("Failed to connect to master", e)
- System.exit(1)
- }
- }
-
- case _ =>
- logError("Invalid master URL: " + masterUrl)
+ logInfo("Connecting to master " + masterUrl)
+ try {
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
+ master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to master", e)
System.exit(1)
}
}
@@ -114,24 +101,27 @@ private[spark] class Worker(
case RegisteredWorker(url) =>
masterWebUiUrl = url
logInfo("Successfully registered with master")
+ context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
+ master ! Heartbeat(workerId)
+ }
case RegisterWorkerFailed(message) =>
logError("Worker registration failed: " + message)
System.exit(1)
- case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
- logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
+ case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
+ logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
- jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
- executors(jobId + "/" + execId) = manager
+ appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
+ executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
- master ! ExecutorStateChanged(jobId, execId, ExecutorState.RUNNING, None, None)
+ master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
- case ExecutorStateChanged(jobId, execId, state, message, exitStatus) =>
- master ! ExecutorStateChanged(jobId, execId, state, message, exitStatus)
- val fullId = jobId + "/" + execId
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
+ master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
logInfo("Executor " + fullId + " finished with state " + state +
@@ -143,8 +133,8 @@ private[spark] class Worker(
memoryUsed -= executor.memory
}
- case KillExecutor(jobId, execId) =>
- val fullId = jobId + "/" + execId
+ case KillExecutor(appId, execId) =>
+ val fullId = appId + "/" + execId
executors.get(fullId) match {
case Some(executor) =>
logInfo("Asked to kill executor " + fullId)
@@ -157,7 +147,7 @@ private[spark] class Worker(
masterDisconnected()
case RequestWorkerState => {
- sender ! WorkerState(ip + ":" + port, workerId, executors.values.toList,
+ sender ! WorkerState(ip, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
}
@@ -183,11 +173,19 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
- args.master, args.workDir)),
- name = "Worker")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+ args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
+ masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
+ val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
+ masterUrl, workDir)), name = "Worker")
+ (actorSystem, boundPort)
+ }
+
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 37524a7c82..08f02bad80 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -92,7 +92,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
"Options:\n" +
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
- " -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" +
+ " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index ef81f072a3..135cc2e86c 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -41,9 +41,9 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
}
} ~
path("log") {
- parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) =>
+ parameters("appId", "executorId", "logType") { (appId, executorId, logType) =>
respondWithMediaType(cc.spray.http.MediaTypes.`text/plain`) {
- getFromFileName("work/" + jobId + "/" + executorId + "/" + logType)
+ getFromFileName("work/" + appId + "/" + executorId + "/" + logType)
}
}
} ~
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index e45288ff53..9a82c3054c 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -4,16 +4,15 @@ import java.nio.ByteBuffer
import spark.Logging
import spark.TaskState.TaskState
import spark.util.AkkaUtils
-import akka.actor.{ActorRef, Actor, Props}
+import akka.actor.{ActorRef, Actor, Props, Terminated}
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue}
-import akka.remote.RemoteClientLifeCycleEvent
import spark.scheduler.cluster._
import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterExecutor
-
private[spark] class StandaloneExecutorBackend(
executor: Executor,
driverUrl: String,
@@ -27,17 +26,11 @@ private[spark] class StandaloneExecutorBackend(
var driver: ActorRef = null
override def preStart() {
- try {
- logInfo("Connecting to driver: " + driverUrl)
- driver = context.actorFor(driverUrl)
- driver ! RegisterExecutor(executorId, hostname, cores)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(driver) // Doesn't work with remote actors, but useful for testing
- } catch {
- case e: Exception =>
- logError("Failed to connect to driver", e)
- System.exit(1)
- }
+ logInfo("Connecting to driver: " + driverUrl)
+ driver = context.actorFor(driverUrl)
+ driver ! RegisterExecutor(executorId, hostname, cores)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
override def receive = {
@@ -52,6 +45,10 @@ private[spark] class StandaloneExecutorBackend(
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
+
+ case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+ logError("Driver terminated or disconnected! Shutting down.")
+ System.exit(1)
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
@@ -71,8 +68,9 @@ private[spark] object StandaloneExecutorBackend {
}
def main(args: Array[String]) {
- if (args.length != 4) {
- System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores>")
+ if (args.length < 4) {
+ //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
+ System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index c7f226044d..b6ec664d7e 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -66,31 +66,28 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
- val thisInstance = this
val selectorThread = new Thread("connection-manager-thread") {
- override def run() {
- thisInstance.run()
- }
+ override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
selectorThread.start()
- def run() {
+ private def run() {
try {
while(!selectorThread.isInterrupted) {
- for( (connectionManagerId, sendingConnection) <- connectionRequests) {
+ for ((connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
- while(!sendMessageRequests.isEmpty) {
+ while (!sendMessageRequests.isEmpty) {
val (message, connection) = sendMessageRequests.dequeue
connection.send(message)
}
}
- while(!keyInterestChangeRequests.isEmpty) {
+ while (!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
val connection = connectionsByKey(key)
val lastOps = key.interestOps()
@@ -126,14 +123,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (key.isValid) {
if (key.isAcceptable) {
acceptConnection(key)
- } else
- if (key.isConnectable) {
+ } else if (key.isConnectable) {
connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
- } else
- if (key.isReadable) {
+ } else if (key.isReadable) {
connectionsByKey(key).read()
- } else
- if (key.isWritable) {
+ } else if (key.isWritable) {
connectionsByKey(key).write()
}
}
@@ -144,7 +138,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- def acceptConnection(key: SelectionKey) {
+ private def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
val newChannel = serverChannel.accept()
val newConnection = new ReceivingConnection(newChannel, selector)
@@ -154,7 +148,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
}
- def addConnection(connection: Connection) {
+ private def addConnection(connection: Connection) {
connectionsByKey += ((connection.key, connection))
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
@@ -165,7 +159,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
connection.onClose(removeConnection)
}
- def removeConnection(connection: Connection) {
+ private def removeConnection(connection: Connection) {
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
@@ -222,16 +216,16 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- def handleConnectionError(connection: Connection, e: Exception) {
+ private def handleConnectionError(connection: Connection, e: Exception) {
logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
removeConnection(connection)
}
- def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
keyInterestChangeRequests += ((connection.key, ops))
}
- def receiveMessage(connection: Connection, message: Message) {
+ private def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
@@ -351,7 +345,6 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private[spark] object ConnectionManager {
def main(args: Array[String]) {
-
val manager = new ConnectionManager(9999)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
index 42f46e06ed..de2dce161a 100644
--- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -20,7 +20,7 @@ private[spark] class ApproximateActionListener[T, U, R](
extends JobListener {
val startTime = System.currentTimeMillis()
- val totalTasks = rdd.splits.size
+ val totalTasks = rdd.partitions.size
var finishedTasks = 0
var failure: Option[Exception] = None // Set if the job has failed (permanently)
var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
@@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
- // Notify any waiting thread that may have called getResult
+ // Notify any waiting thread that may have called awaitResult
this.notifyAll()
}
}
@@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
* Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done.
*/
- def getResult(): PartialResult[R] = synchronized {
+ def awaitResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout
while (true) {
val time = System.currentTimeMillis()
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 2c022f88e0..7348c4f15b 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,9 +1,9 @@
package spark.rdd
import scala.collection.mutable.HashMap
-import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext}
+import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
-private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
+private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
val index = idx
}
@@ -11,10 +11,6 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
- @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
- new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
- }).toArray
-
@transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
@@ -22,11 +18,14 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
HashMap(blockIds.zip(locations):_*)
}
- override def getSplits = splits_
+ override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
+ new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
+ }).toArray
- override def compute(split: Split, context: TaskContext): Iterator[T] = {
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
- val blockId = split.asInstanceOf[BlockRDDSplit].blockId
+ val blockId = split.asInstanceOf[BlockRDDPartition].blockId
blockManager.get(blockId) match {
case Some(block) => block.asInstanceOf[Iterator[T]]
case None =>
@@ -34,11 +33,8 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def getPreferredLocations(split: Split) =
- locations_(split.asInstanceOf[BlockRDDSplit].blockId)
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ locations_(split.asInstanceOf[BlockRDDPartition].blockId)
- override def clearDependencies() {
- splits_ = null
- }
}
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 0f9ca06531..38600b8be4 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -5,22 +5,22 @@ import spark._
private[spark]
-class CartesianSplit(
+class CartesianPartition(
idx: Int,
@transient rdd1: RDD[_],
@transient rdd2: RDD[_],
s1Index: Int,
s2Index: Int
- ) extends Split {
- var s1 = rdd1.splits(s1Index)
- var s2 = rdd2.splits(s2Index)
+ ) extends Partition {
+ var s1 = rdd1.partitions(s1Index)
+ var s2 = rdd2.partitions(s2Index)
override val index: Int = idx
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- s1 = rdd1.splits(s1Index)
- s2 = rdd2.splits(s2Index)
+ s1 = rdd1.partitions(s1Index)
+ s2 = rdd2.partitions(s2Index)
oos.defaultWriteObject()
}
}
@@ -33,39 +33,40 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
extends RDD[Pair[T, U]](sc, Nil)
with Serializable {
- val numSplitsInRdd2 = rdd2.splits.size
+ val numPartitionsInRdd2 = rdd2.partitions.size
- override def getSplits: Array[Split] = {
+ override def getPartitions: Array[Partition] = {
// create the cross product split
- val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
- for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
- val idx = s1.index * numSplitsInRdd2 + s2.index
- array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index)
+ val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size)
+ for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
+ val idx = s1.index * numPartitionsInRdd2 + s2.index
+ array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
}
array
}
- override def getPreferredLocations(split: Split) = {
- val currSplit = split.asInstanceOf[CartesianSplit]
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
- override def compute(split: Split, context: TaskContext) = {
- val currSplit = split.asInstanceOf[CartesianSplit]
+ override def compute(split: Partition, context: TaskContext) = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
for (x <- rdd1.iterator(currSplit.s1, context);
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) {
- def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
+ def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2)
},
new NarrowDependency(rdd2) {
- def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
+ def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2)
}
)
override def clearDependencies() {
+ super.clearDependencies()
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 96b593ba7c..36bfb0355e 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -9,7 +9,7 @@ import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
-private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
+private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
@@ -20,29 +20,27 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
- @transient val splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val dirContents = fs.listStatus(new Path(checkpointPath))
val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
- val numSplits = splitFiles.size
+ val numPartitions = splitFiles.size
if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
+ !splitFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1))) {
throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
}
- Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
+ Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
checkpointData = Some(new RDDCheckpointData[T](this))
checkpointData.get.cpFile = Some(checkpointPath)
- override def getSplits = splits_
-
- override def getPreferredLocations(split: Split): Seq[String] = {
+ override def getPreferredLocations(split: Partition): Seq[String] = {
val status = fs.getFileStatus(new Path(checkpointPath))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
- override def compute(split: Split, context: TaskContext): Iterator[T] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
CheckpointRDD.readFromFile(file, context)
}
@@ -109,7 +107,7 @@ private[spark] object CheckpointRDD extends Logging {
deserializeStream.asIterator.asInstanceOf[Iterator[T]]
}
- // Test whether CheckpointRDD generate expected number of splits despite
+ // Test whether CheckpointRDD generate expected number of partitions despite
// each split file having multiple blocks. This needs to be run on a
// cluster (mesos or standalone) using HDFS.
def main(args: Array[String]) {
@@ -122,8 +120,8 @@ private[spark] object CheckpointRDD extends Logging {
val fs = path.getFileSystem(new Configuration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
- assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
- assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+ assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
+ assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
fs.delete(path)
}
}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 8fafd27bb6..5200fb6b65 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -5,7 +5,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
+import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -14,13 +14,13 @@ private[spark] sealed trait CoGroupSplitDep extends Serializable
private[spark] case class NarrowCoGroupSplitDep(
rdd: RDD[_],
splitIndex: Int,
- var split: Split
+ var split: Partition
) extends CoGroupSplitDep {
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- split = rdd.splits(splitIndex)
+ split = rdd.partitions(splitIndex)
oos.defaultWriteObject()
}
}
@@ -28,7 +28,7 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
-class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
+class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
@@ -40,50 +40,47 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
- extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
+ extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
- val aggr = new CoGroupAggregator
+ private val aggr = new CoGroupAggregator
- @transient var deps_ = {
- val deps = new ArrayBuffer[Dependency[_]]
- for ((rdd, index) <- rdds.zipWithIndex) {
+ override def getDependencies: Seq[Dependency[_]] = {
+ rdds.map { rdd =>
if (rdd.partitioner == Some(part)) {
logInfo("Adding one-to-one dependency with " + rdd)
- deps += new OneToOneDependency(rdd)
+ new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+ new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
}
}
- deps.toList
}
- override def getDependencies = deps_
-
- @transient var splits_ : Array[Split] = {
- val array = new Array[Split](part.numPartitions)
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](part.numPartitions)
for (i <- 0 until array.size) {
- array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
+ // Each CoGroupPartition will have a dependency per contributing RDD
+ array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
+ // Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
+ new ShuffleCoGroupSplitDep(s.shuffleId)
case _ =>
- new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep
+ new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
}.toList)
}
array
}
- override def getSplits = splits_
-
override val partitioner = Some(part)
- override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
- val split = s.asInstanceOf[CoGroupSplit]
+ override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
+ // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
val seq = map.get(k)
@@ -96,7 +93,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
@@ -104,21 +101,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
- def mergePair(pair: (K, Seq[Any])) {
- val mySeq = getSeq(pair._1)
- for (v <- pair._2)
- mySeq(depNum) += v
- }
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
+ for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
+ getSeq(k)(depNum) ++= vs
+ }
}
}
JavaConversions.mapAsScalaMap(map).iterator
}
override def clearDependencies() {
- deps_ = null
- splits_ = null
+ super.clearDependencies()
rdds = null
}
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 4c57434b65..0d16cf6e85 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,19 +1,19 @@
package spark.rdd
-import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext}
+import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
-private[spark] case class CoalescedRDDSplit(
+private[spark] case class CoalescedRDDPartition(
index: Int,
@transient rdd: RDD[_],
parentsIndices: Array[Int]
- ) extends Split {
- var parents: Seq[Split] = parentsIndices.map(rdd.splits(_))
+ ) extends Partition {
+ var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- parents = parentsIndices.map(rdd.splits(_))
+ parents = parentsIndices.map(rdd.partitions(_))
oos.defaultWriteObject()
}
}
@@ -31,33 +31,34 @@ class CoalescedRDD[T: ClassManifest](
maxPartitions: Int)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
- override def getSplits: Array[Split] = {
- val prevSplits = prev.splits
+ override def getPartitions: Array[Partition] = {
+ val prevSplits = prev.partitions
if (prevSplits.length < maxPartitions) {
- prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
+ prevSplits.map(_.index).map{idx => new CoalescedRDDPartition(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
- new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray)
+ new CoalescedRDDPartition(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
}
- override def compute(split: Split, context: TaskContext): Iterator[T] = {
- split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ split.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentSplit =>
firstParent[T].iterator(parentSplit, context)
}
}
- override def getDependencies: Seq[Dependency[_]] = List(
- new NarrowDependency(prev) {
+ override def getDependencies: Seq[Dependency[_]] = {
+ Seq(new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
- splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
- }
- )
+ partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices
+ })
+ }
override def clearDependencies() {
+ super.clearDependencies()
prev = null
}
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index 6dbe235bd9..c84ec39d21 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -1,16 +1,16 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{OneToOneDependency, RDD, Partition, TaskContext}
private[spark] class FilteredRDD[T: ClassManifest](
prev: RDD[T],
f: T => Boolean)
extends RDD[T](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[T].iterator(split, context).filter(f)
}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 1b604c66e2..8ebc778925 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,6 +1,6 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
@@ -9,8 +9,8 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
f: T => TraversableOnce[U])
extends RDD[U](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[T].iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index 051bffed19..e16c7ba881 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,12 +1,12 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
extends RDD[Array[T]](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
Array(firstParent[T].iterator(split, context).toArray).iterator
}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index f547f53812..8139a2a40c 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -15,14 +15,14 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
+import spark.{Dependency, RDD, SerializableWritable, SparkContext, Partition, TaskContext}
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
-private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
- extends Split {
+private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
+ extends Partition {
val inputSplit = new SerializableWritable[InputSplit](s)
@@ -45,15 +45,14 @@ class HadoopRDD[K, V](
extends RDD[(K, V)](sc, Nil) {
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
- val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+ private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
- @transient
- val splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val inputFormat = createInputFormat(conf)
val inputSplits = inputFormat.getSplits(conf, minSplits)
- val array = new Array[Split](inputSplits.size)
+ val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
- array(i) = new HadoopSplit(id, i, inputSplits(i))
+ array(i) = new HadoopPartition(id, i, inputSplits(i))
}
array
}
@@ -63,10 +62,8 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def getSplits = splits_
-
- override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopSplit]
+ override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopPartition]
var reader: RecordReader[K, V] = null
val conf = confBroadcast.value.value
@@ -109,9 +106,9 @@ class HadoopRDD[K, V](
}
}
- override def getPreferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Partition): Seq[String] = {
// TODO: Filtering out "localhost" in case of file:// URLs
- val hadoopSplit = split.asInstanceOf[HadoopSplit]
+ val hadoopSplit = split.asInstanceOf[HadoopPartition]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index 073f7d7d2a..d283c5b2bb 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,6 +1,6 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
@@ -13,8 +13,8 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override val partitioner =
if (preservesPartitioning) firstParent[T].partitioner else None
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
f(firstParent[T].iterator(split, context))
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala
index 2ddc3d01b6..afb7504ba1 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala
@@ -1,24 +1,24 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
/**
- * A variant of the MapPartitionsRDD that passes the split index into the
+ * A variant of the MapPartitionsRDD that passes the partition index into the
* closure. This can be used to generate or collect partition specific
* information such as the number of tuples in a partition.
*/
private[spark]
-class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
override val partitioner = if (preservesPartitioning) prev.partitioner else None
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
f(split.index, firstParent[T].iterator(split, context))
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 5466c9c657..af07311b6d 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,13 +1,13 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[T].iterator(split, context).map(f)
}
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index c3b155fcbd..ebd4c3f0e2 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -7,12 +7,12 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
+import spark.{Dependency, RDD, SerializableWritable, SparkContext, Partition, TaskContext}
private[spark]
-class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
- extends Split {
+class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
+ extends Partition {
val serializableHadoopSplit = new SerializableWritable(rawSplit)
@@ -29,7 +29,7 @@ class NewHadoopRDD[K, V](
with HadoopMapReduceUtil {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
- val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+ private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
// private val serializableConf = new SerializableWritable(conf)
private val jobtrackerId: String = {
@@ -39,21 +39,19 @@ class NewHadoopRDD[K, V](
@transient private val jobId = new JobID(jobtrackerId, id)
- @transient private val splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
- val result = new Array[Split](rawSplits.size)
+ val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
- result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
}
result
}
- override def getSplits = splits_
-
- override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopSplit]
+ override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopPartition]
val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
@@ -85,8 +83,8 @@ class NewHadoopRDD[K, V](
}
}
- override def getPreferredLocations(split: Split) = {
- val theSplit = split.asInstanceOf[NewHadoopSplit]
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val theSplit = split.asInstanceOf[NewHadoopPartition]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
}
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
index 10adcd53ec..07585a88ce 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
@@ -1,28 +1,29 @@
-package spark
+package spark.rdd
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
+import spark.{RDD, TaskContext, SparkContext, Partition}
-private[spark] class ParallelCollectionSplit[T: ClassManifest](
+private[spark] class ParallelCollectionPartition[T: ClassManifest](
val rddId: Long,
val slice: Int,
values: Seq[T])
- extends Split with Serializable {
+ extends Partition with Serializable {
def iterator: Iterator[T] = values.iterator
override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
override def equals(other: Any): Boolean = other match {
- case that: ParallelCollectionSplit[_] => (this.rddId == that.rddId && this.slice == that.slice)
+ case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice)
case _ => false
}
override val index: Int = slice
}
-private[spark] class ParallelCollection[T: ClassManifest](
+private[spark] class ParallelCollectionRDD[T: ClassManifest](
@transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
@@ -33,26 +34,20 @@ private[spark] class ParallelCollection[T: ClassManifest](
// instead.
// UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient var splits_ : Array[Split] = {
- val slices = ParallelCollection.slice(data, numSlices).toArray
- slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
+ override def getPartitions: Array[Partition] = {
+ val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
+ slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
- override def getSplits = splits_
+ override def compute(s: Partition, context: TaskContext) =
+ s.asInstanceOf[ParallelCollectionPartition[T]].iterator
- override def compute(s: Split, context: TaskContext) =
- s.asInstanceOf[ParallelCollectionSplit[T]].iterator
-
- override def getPreferredLocations(s: Split): Seq[String] = {
+ override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)
}
-
- override def clearDependencies() {
- splits_ = null
- }
}
-private object ParallelCollection {
+private object ParallelCollectionRDD {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
index a50ce75171..f2f4fd56d1 100644
--- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
@@ -1,9 +1,9 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{NarrowDependency, RDD, SparkEnv, Partition, TaskContext}
-class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
+class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition {
override val index = idx
}
@@ -16,15 +16,15 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
extends NarrowDependency[T](rdd) {
@transient
- val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
- .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
+ val partitions: Array[Partition] = rdd.partitions.filter(s => partitionFilterFunc(s.index))
+ .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}
/**
- * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on
+ * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
@@ -34,9 +34,9 @@ class PartitionPruningRDD[T: ClassManifest](
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
- override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
- split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
+ override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator(
+ split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context)
- override protected def getSplits =
+ override protected def getPartitions: Array[Partition] =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
}
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 6631f83510..962a1b21ad 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -8,7 +8,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
-import spark.{RDD, SparkEnv, Split, TaskContext}
+import spark.{RDD, SparkEnv, Partition, TaskContext}
/**
@@ -27,9 +27,9 @@ class PipedRDD[T: ClassManifest](
// using a standard StringTokenizer (i.e. by spaces)
def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext): Iterator[String] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index e24ad23b21..243673f151 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -5,10 +5,10 @@ import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
-class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
+class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
override val index: Int = prev.index
}
@@ -19,18 +19,16 @@ class SampledRDD[T: ClassManifest](
seed: Int)
extends RDD[T](prev) {
- @transient var splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val rg = new Random(seed)
- firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
+ firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt))
}
- override def getSplits = splits_
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev)
- override def getPreferredLocations(split: Split) =
- firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
-
- override def compute(splitIn: Split, context: TaskContext) = {
- val split = splitIn.asInstanceOf[SampledRDDSplit]
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
+ val split = splitIn.asInstanceOf[SampledRDDPartition]
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
@@ -48,8 +46,4 @@ class SampledRDD[T: ClassManifest](
firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
-
- override def clearDependencies() {
- splits_ = null
- }
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index d396478673..c2f118305f 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,9 +1,9 @@
package spark.rdd
-import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
+import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
-private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
+private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
override def hashCode(): Int = idx
}
@@ -22,9 +22,11 @@ class ShuffledRDD[K, V](
override val partitioner = Some(part)
- override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
+ override def getPartitions: Array[Partition] = {
+ Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
+ }
- override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 26a2d511f2..2c52a67e22 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,13 +1,13 @@
package spark.rdd
import scala.collection.mutable.ArrayBuffer
-import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import spark.{Dependency, RangeDependency, RDD, SparkContext, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
-private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
- extends Split {
+private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+ extends Partition {
- var split: Split = rdd.splits(splitIndex)
+ var split: Partition = rdd.partitions(splitIndex)
def iterator(context: TaskContext) = rdd.iterator(split, context)
@@ -18,7 +18,7 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- split = rdd.splits(splitIndex)
+ split = rdd.partitions(splitIndex)
oos.defaultWriteObject()
}
}
@@ -28,11 +28,11 @@ class UnionRDD[T: ClassManifest](
@transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
- override def getSplits: Array[Split] = {
- val array = new Array[Split](rdds.map(_.splits.size).sum)
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](rdds.map(_.partitions.size).sum)
var pos = 0
- for (rdd <- rdds; split <- rdd.splits) {
- array(pos) = new UnionSplit(pos, rdd, split.index)
+ for (rdd <- rdds; split <- rdd.partitions) {
+ array(pos) = new UnionPartition(pos, rdd, split.index)
pos += 1
}
array
@@ -42,19 +42,15 @@ class UnionRDD[T: ClassManifest](
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
- deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
- pos += rdd.splits.size
+ deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size)
+ pos += rdd.partitions.size
}
deps
}
- override def compute(s: Split, context: TaskContext): Iterator[T] =
- s.asInstanceOf[UnionSplit[T]].iterator(context)
+ override def compute(s: Partition, context: TaskContext): Iterator[T] =
+ s.asInstanceOf[UnionPartition[T]].iterator(context)
- override def getPreferredLocations(s: Split): Seq[String] =
- s.asInstanceOf[UnionSplit[T]].preferredLocations()
-
- override def clearDependencies() {
- rdds = null
- }
+ override def getPreferredLocations(s: Partition): Seq[String] =
+ s.asInstanceOf[UnionPartition[T]].preferredLocations()
}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index e5df6d8c72..e80ec17aa5 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,17 +1,17 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
-private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
+private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
idx: Int,
@transient rdd1: RDD[T],
@transient rdd2: RDD[U]
- ) extends Split {
+ ) extends Partition {
- var split1 = rdd1.splits(idx)
- var split2 = rdd1.splits(idx)
+ var split1 = rdd1.partitions(idx)
+ var split2 = rdd1.partitions(idx)
override val index: Int = idx
def splits = (split1, split2)
@@ -19,8 +19,8 @@ private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- split1 = rdd1.splits(idx)
- split2 = rdd2.splits(idx)
+ split1 = rdd1.partitions(idx)
+ split2 = rdd2.partitions(idx)
oos.defaultWriteObject()
}
}
@@ -29,31 +29,31 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
sc: SparkContext,
var rdd1: RDD[T],
var rdd2: RDD[U])
- extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
- with Serializable {
+ extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) {
- override def getSplits: Array[Split] = {
- if (rdd1.splits.size != rdd2.splits.size) {
+ override def getPartitions: Array[Partition] = {
+ if (rdd1.partitions.size != rdd2.partitions.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
- val array = new Array[Split](rdd1.splits.size)
- for (i <- 0 until rdd1.splits.size) {
- array(i) = new ZippedSplit(i, rdd1, rdd2)
+ val array = new Array[Partition](rdd1.partitions.size)
+ for (i <- 0 until rdd1.partitions.size) {
+ array(i) = new ZippedPartition(i, rdd1, rdd2)
}
array
}
- override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
- val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
+ val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits
rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
}
- override def getPreferredLocations(s: Split): Seq[String] = {
- val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits
rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
}
override def clearDependencies() {
+ super.clearDependencies()
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index b130be6a38..bf0837c066 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private[spark]
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+class DAGScheduler(
+ taskSched: TaskScheduler,
+ mapOutputTracker: MapOutputTracker,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv)
+ extends TaskSchedulerListener with Logging {
+
+ def this(taskSched: TaskScheduler) {
+ this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ }
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
@@ -66,10 +75,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
- val env = SparkEnv.get
- val mapOutputTracker = env.mapOutputTracker
- val blockManagerMaster = env.blockManager.master
-
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
// sent with every task. When we detect a node failing, we note the current generation number
// and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
@@ -90,16 +95,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
- }
- }.start()
+ def start() {
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+ }
- def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
- val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
locations => locations.map(_.ip).toList
}.toArray
@@ -107,7 +114,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
cacheLocs(rdd.id)
}
- def clearCacheLocs() {
+ private def clearCacheLocs() {
cacheLocs.clear()
}
@@ -116,7 +123,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
- def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -131,12 +138,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
if (shuffleDep != None) {
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of partitions is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
@@ -148,14 +155,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
- def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+ private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
if (!visited(r)) {
visited += r
// Kind of ugly: need to register RDDs with the cache here since
- // we can't do it in its constructor because # of splits is unknown
+ // we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@@ -170,25 +177,22 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
parents.toList
}
- def getMissingParentStages(stage: Stage): List[Stage] = {
+ private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
- val locs = getCacheLocs(rdd)
- for (p <- 0 until rdd.splits.size) {
- if (locs(p) == Nil) {
- for (dep <- rdd.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val mapStage = getShuffleMapStage(shufDep, stage.priority)
- if (!mapStage.isAvailable) {
- missing += mapStage
- }
- case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
- }
+ if (getCacheLocs(rdd).contains(Nil)) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ if (!mapStage.isAvailable) {
+ missing += mapStage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
}
}
}
@@ -198,23 +202,45 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList
}
+ /**
+ * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
+ * JobWaiter whose getResult() method will return the result of the job when it is complete.
+ *
+ * The job is assumed to have at least one partition; zero partition jobs should be handled
+ * without a JobSubmitted event.
+ */
+ private[scheduler] def prepareJob[T, U: ClassManifest](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ callSite: String,
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
+ : (JobSubmitted, JobWaiter[U]) =
+ {
+ assert(partitions.size > 0)
+ val waiter = new JobWaiter(partitions.size, resultHandler)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ return (toSubmit, waiter)
+ }
+
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
- allowLocal: Boolean)
- : Array[U] =
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
{
if (partitions.size == 0) {
- return new Array[U](0)
+ return
}
- val waiter = new JobWaiter(partitions.size)
- val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
- waiter.getResult() match {
- case JobSucceeded(results: Seq[_]) =>
- return results.asInstanceOf[Seq[U]].toArray
+ val (toSubmit, waiter) = prepareJob(
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ eventQueue.put(toSubmit)
+ waiter.awaitResult() match {
+ case JobSucceeded => {}
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
@@ -231,92 +257,119 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
{
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val partitions = (0 until rdd.splits.size).toArray
+ val partitions = (0 until rdd.partitions.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
- return listener.getResult() // Will throw an exception if the job fails
+ return listener.awaitResult() // Will throw an exception if the job fails
+ }
+
+ /**
+ * Process one event retrieved from the event queue.
+ * Returns true if we should stop the event loop.
+ */
+ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ val runId = nextRunId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, runId)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+ clearCacheLocs()
+ logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+ " output partitions (allowLocal=" + allowLocal + ")")
+ logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case ExecutorLost(execId) =>
+ handleExecutorLost(execId)
+
+ case completion: CompletionEvent =>
+ handleTaskCompletion(completion)
+
+ case TaskSetFailed(taskSet, reason) =>
+ abortStage(idToStage(taskSet.stageId), reason)
+
+ case StopDAGScheduler =>
+ // Cancel any active jobs
+ for (job <- activeJobs) {
+ val error = new SparkException("Job cancelled because SparkContext was shut down")
+ job.listener.jobFailed(error)
+ }
+ return true
+ }
+ return false
}
/**
+ * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+ * the last fetch failure.
+ */
+ private[scheduler] def resubmitFailedStages() {
+ logInfo("Resubmitting failed stages")
+ clearCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+ /**
+ * Check for waiting or failed stages which are now eligible for resubmission.
+ * Ordinarily run on every iteration of the event loop.
+ */
+ private[scheduler] def submitWaitingStages() {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logTrace("Checking for newly runnable parent stages")
+ logTrace("running: " + running)
+ logTrace("waiting: " + waiting)
+ logTrace("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+
+ /**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
- def run() {
+ private def run() {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
- event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
- val runId = nextRunId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- clearCacheLocs()
- logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
- " output partitions")
- logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
- if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
- // Compute very short actions like first() or take() with no parent stages locally.
- runLocally(job)
- } else {
- activeJobs += job
- resultStageToJob(finalStage) = job
- submitStage(finalStage)
- }
-
- case ExecutorLost(execId) =>
- handleExecutorLost(execId)
-
- case completion: CompletionEvent =>
- handleTaskCompletion(completion)
-
- case TaskSetFailed(taskSet, reason) =>
- abortStage(idToStage(taskSet.stageId), reason)
-
- case StopDAGScheduler =>
- // Cancel any active jobs
- for (job <- activeJobs) {
- val error = new SparkException("Job cancelled because SparkContext was shut down")
- job.listener.jobFailed(error)
- }
+ if (event != null) {
+ if (processEvent(event)) {
return
-
- case null =>
- // queue.poll() timed out, ignore it
+ }
}
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- clearCacheLocs()
- val failed2 = failed.toArray
- failed.clear()
- for (stage <- failed2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ resubmitFailedStages()
} else {
- // TODO: We might want to run this less often, when we are sure that something has become
- // runnable that wasn't before.
- logTrace("Checking for newly runnable parent stages")
- logTrace("running: " + running)
- logTrace("waiting: " + waiting)
- logTrace("failed: " + failed)
- val waiting2 = waiting.toArray
- waiting.clear()
- for (stage <- waiting2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ submitWaitingStages()
}
}
}
@@ -326,14 +379,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
- def runLocally(job: ActiveJob) {
+ private def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
try {
SparkEnv.set(env)
val rdd = job.finalStage.rdd
- val split = rdd.splits(job.partitions(0))
+ val split = rdd.partitions(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
@@ -349,13 +402,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
}
- def submitStage(stage: Stage) {
+ /** Submits stage, but first recursively submits any missing parents. */
+ private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
@@ -367,7 +421,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
- def submitMissingTasks(stage: Stage) {
+ /** Called when stage's parents are available and we can now do its task. */
+ private def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -388,7 +443,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
if (tasks.size > 0) {
- logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+ logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
@@ -407,7 +462,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
- def handleTaskCompletion(event: CompletionEvent) {
+ private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
@@ -492,7 +547,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage)
}
}
@@ -541,12 +596,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Optionally the generation during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
+ private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
failedGeneration(execId) = currentGeneration
logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
- env.blockManager.master.removeExecutor(execId)
+ blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnExecutor(execId)
@@ -567,7 +622,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- def abortStage(failedStage: Stage, reason: String) {
+ private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
@@ -583,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Return true if one of stage's ancestors is target.
*/
- def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+ private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
}
@@ -610,14 +665,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds.contains(target.rdd)
}
- def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
return cached
}
// If the RDD has some placement preferences (as is the case for input RDDs), get those
- val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
+ val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
if (rddPrefs != Nil) {
return rddPrefs
}
@@ -636,7 +691,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
- def cleanup(cleanupTime: Long) {
+ private def cleanup(cleanupTime: Long) {
var sizeBefore = idToStage.size
idToStage.clearOldValues(cleanupTime)
logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
index c4a74e526f..654131ee84 100644
--- a/core/src/main/scala/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -5,5 +5,5 @@ package spark.scheduler
*/
private[spark] sealed trait JobResult
-private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
+private[spark] case object JobSucceeded extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
index b3d4feebe5..3cc6a86345 100644
--- a/core/src/main/scala/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -3,10 +3,12 @@ package spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
- * An object that waits for a DAGScheduler job to complete.
+ * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
+ * results to the given handler function.
*/
-private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
- private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+ extends JobListener {
+
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
@@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
- taskResults(index) = result
+ resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
- jobResult = JobSucceeded(taskResults)
+ jobResult = JobSucceeded
this.notifyAll()
}
}
@@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
}
}
- def getResult(): JobResult = synchronized {
+ def awaitResult(): JobResult = synchronized {
while (!jobFinished) {
this.wait()
}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 8cd4c661eb..1721f78f48 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -67,7 +67,7 @@ private[spark] class ResultTask[T, U](
var split = if (rdd == null) {
null
} else {
- rdd.splits(partition)
+ rdd.partitions(partition)
}
override def run(attemptId: Long): U = {
@@ -85,7 +85,7 @@ private[spark] class ResultTask[T, U](
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.splits(partition)
+ split = rdd.partitions(partition)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
@@ -107,6 +107,6 @@ private[spark] class ResultTask[T, U](
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
partition = in.readInt()
val outputId = in.readInt()
- split = in.readObject().asInstanceOf[Split]
+ split = in.readObject().asInstanceOf[Partition]
}
}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 83641a2a84..59ee3c0a09 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask {
return old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(dep)
@@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@@ -86,12 +86,12 @@ private[spark] class ShuffleMapTask(
var split = if (rdd == null) {
null
} else {
- rdd.splits(partition)
+ rdd.partitions(partition)
}
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.splits(partition)
+ split = rdd.partitions(partition)
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
@@ -112,7 +112,7 @@ private[spark] class ShuffleMapTask(
dep = dep_
partition = in.readInt()
generation = in.readLong()
- split = in.readObject().asInstanceOf[Split]
+ split = in.readObject().asInstanceOf[Partition]
}
override def run(attemptId: Long): MapStatus = {
@@ -127,7 +127,6 @@ private[spark] class ShuffleMapTask(
val bucketId = dep.partitioner.getPartition(pair._1)
buckets(bucketId) += pair
}
- val bucketIterators = buckets.map(_.iterator)
val compressedSizes = new Array[Byte](numOutputSplits)
@@ -135,7 +134,7 @@ private[spark] class ShuffleMapTask(
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = bucketIterators(i)
+ val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 374114d870..552061e46b 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -28,7 +28,7 @@ private[spark] class Stage(
extends Logging {
val isShuffleMap = shuffleDep != None
- val numPartitions = rdd.splits.size
+ val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 0b4177805b..1e4fbdb874 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -86,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def submitTasks(taskSet: TaskSet) {
+ override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
diff --git a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala
index bba7de6a65..8bf838209f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala
@@ -12,10 +12,10 @@ class ExecutorLossReason(val message: String) {
private[spark]
case class ExecutorExited(val exitCode: Int)
- extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) {
+ extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) {
}
private[spark]
case class SlaveLost(_message: String = "Slave lost")
- extends ExecutorLossReason(_message) {
+ extends ExecutorLossReason(_message) {
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
index ddcd64d7c6..9ac875de3a 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -1,5 +1,7 @@
package spark.scheduler.cluster
+import spark.Utils
+
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
@@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
+ // Memory used by each executor (in megabytes)
+ protected val executorMemory = {
+ // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+ Option(System.getProperty("spark.executor.memory"))
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ }
+
+
// TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 9760d23072..bb289c9cf3 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -2,14 +2,14 @@ package spark.scheduler.cluster
import spark.{Utils, Logging, SparkContext}
import spark.deploy.client.{Client, ClientListener}
-import spark.deploy.{Command, JobDescription}
+import spark.deploy.{Command, ApplicationDescription}
import scala.collection.mutable.HashMap
private[spark] class SparkDeploySchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
- jobName: String)
+ appName: String)
extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
with ClientListener
with Logging {
@@ -20,16 +20,6 @@ private[spark] class SparkDeploySchedulerBackend(
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
- // Memory used by each executor (in megabytes)
- val executorMemory = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
override def start() {
super.start()
@@ -39,10 +29,11 @@ private[spark] class SparkDeploySchedulerBackend(
StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
- val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
- val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
+ val sparkHome = sc.getSparkHome().getOrElse(
+ throw new IllegalArgumentException("must supply spark home for spark standalone"))
+ val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome)
- client = new Client(sc.env.actorSystem, master, jobDesc, this)
+ client = new Client(sc.env.actorSystem, master, appDesc, this)
client.start()
}
@@ -55,8 +46,8 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
- override def connected(jobId: String) {
- logInfo("Connected to Spark cluster with job ID " + jobId)
+ override def connected(appId: String) {
+ logInfo("Connected to Spark cluster with app ID " + appId)
}
override def disconnected() {
@@ -77,6 +68,6 @@ private[spark] class SparkDeploySchedulerBackend(
case None => SlaveLost(message)
}
logInfo("Executor %s removed: %s".format(executorId, message))
- scheduler.executorLost(executorId, reason)
+ removeExecutor(executorId, reason.toString)
}
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index da7dcf4b6b..d766067824 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -37,3 +37,6 @@ object StatusUpdate {
// Internal messages in driver
private[spark] case object ReviveOffers extends StandaloneClusterMessage
private[spark] case object StopDriver extends StandaloneClusterMessage
+
+private[spark] case class RemoveExecutor(executorId: String, reason: String)
+ extends StandaloneClusterMessage
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 082022be1c..d606432572 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -68,6 +68,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! true
context.stop(self)
+ case RemoveExecutor(executorId, reason) =>
+ removeExecutor(executorId, reason)
+ sender ! true
+
case Terminated(actor) =>
actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
@@ -100,16 +104,18 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Remove a disconnected slave from the cluster
def removeExecutor(executorId: String, reason: String) {
- logInfo("Slave " + executorId + " disconnected, so removing it")
- val numCores = freeCores(executorId)
- actorToExecutorId -= executorActor(executorId)
- addressToExecutorId -= executorAddress(executorId)
- executorActor -= executorId
- executorHost -= executorId
- freeCores -= executorId
- executorHost -= executorId
- totalCoreCount.addAndGet(-numCores)
- scheduler.executorLost(executorId, SlaveLost(reason))
+ if (executorActor.contains(executorId)) {
+ logInfo("Executor " + executorId + " disconnected, so removing it")
+ val numCores = freeCores(executorId)
+ actorToExecutorId -= executorActor(executorId)
+ addressToExecutorId -= executorAddress(executorId)
+ executorActor -= executorId
+ executorHost -= executorId
+ freeCores -= executorId
+ executorHost -= executorId
+ totalCoreCount.addAndGet(-numCores)
+ scheduler.executorLost(executorId, SlaveLost(reason))
+ }
}
}
@@ -139,7 +145,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
} catch {
case e: Exception =>
- throw new SparkException("Error stopping standalone scheduler's master actor", e)
+ throw new SparkException("Error stopping standalone scheduler's driver actor", e)
}
}
@@ -148,6 +154,18 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
+
+ // Called by subclasses when notified of a lost worker
+ def removeExecutor(executorId: String, reason: String) {
+ try {
+ val timeout = 5.seconds
+ val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
+ Await.result(future, timeout)
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error notifying standalone scheduler's driver actor", e)
+ }
+ }
}
private[spark] object StandaloneSchedulerBackend {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 26201ad0dd..3dabdd76b1 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -17,10 +17,7 @@ import java.nio.ByteBuffer
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
-private[spark] class TaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet)
- extends Logging {
+private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
@@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
}
// Add a task to all the pending-task lists that it should be on.
- def addPendingTask(index: Int) {
+ private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
@@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
- def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
@@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
- def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
@@ -137,7 +134,7 @@ private[spark] class TaskSetManager(
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
- def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
val hostsAlive = sched.hostsAlive
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find {
@@ -162,7 +159,7 @@ private[spark] class TaskSetManager(
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
- def findTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host))
if (localTask != None) {
return localTask
@@ -184,7 +181,7 @@ private[spark] class TaskSetManager(
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
// no preferred locations (in which we still count the launch as preferred).
- def isPreferredLocation(task: Task[_], host: String): Boolean = {
+ private def isPreferredLocation(task: Task[_], host: String): Boolean = {
val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty)
}
@@ -335,7 +332,7 @@ private[spark] class TaskSetManager(
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 9ff7c02097..482d1cc853 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
- logInfo("Running task " + idInJob)
+ logInfo("Running " + task)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
- logInfo("Finished task " + idInJob)
+ logInfo("Finished " + task)
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index 7bf56a05d6..f4a2994b6d 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -28,23 +28,13 @@ private[spark] class CoarseMesosSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
- frameworkName: String)
+ appName: String)
extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
with MScheduler
with Logging {
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
- // Memory used by each executor (in megabytes)
- val executorMemory = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@@ -86,7 +76,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try { {
val ret = driver.run()
@@ -249,7 +239,11 @@ private[spark] class CoarseMesosSchedulerBackend(
override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
logInfo("Mesos slave lost: " + slaveId.getValue)
synchronized {
- slaveIdsWithExecutors -= slaveId.getValue
+ if (slaveIdsWithExecutors.contains(slaveId.getValue)) {
+ // Note that the slave ID corresponds to the executor ID on that slave
+ slaveIdsWithExecutors -= slaveId.getValue
+ removeExecutor(slaveId.getValue, "Mesos slave lost")
+ }
}
}
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index eab1c60e0b..ca7fab4cc5 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -24,21 +24,11 @@ private[spark] class MesosSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
- frameworkName: String)
+ appName: String)
extends SchedulerBackend
with MScheduler
with Logging {
- // Memory used by each executor (in megabytes)
- val EXECUTOR_MEMORY = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@@ -59,7 +49,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -89,7 +79,7 @@ private[spark] class MesosSchedulerBackend(
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build())
+ .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
.build()
val command = CommandInfo.newBuilder()
.setValue(execScript)
@@ -161,7 +151,7 @@ private[spark] class MesosSchedulerBackend(
def enoughMemory(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
val slaveId = o.getSlaveId.getValue
- mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
+ mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
}
for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index c61fd75c2b..2462721fb8 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -513,7 +513,7 @@ class BlockManager(
}
}
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // Partition local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
@@ -585,7 +585,7 @@ class BlockManager(
resultsGotten += 1
val result = results.take()
bytesInFlight -= result.size
- if (!fetchRequests.isEmpty &&
+ while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
@@ -950,6 +950,7 @@ class BlockManager(
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
+ metadataCleaner.cancel()
logInfo("BlockManager stopped")
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 36398095a2..7389bee150 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -27,8 +27,6 @@ private[spark] class BlockManagerMaster(
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
- val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
- val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val timeout = 10.seconds
var driverActor: ActorRef = {
@@ -117,6 +115,10 @@ private[spark] class BlockManagerMaster(
askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
+ def getStorageStatus: Array[StorageStatus] = {
+ askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
+ }
+
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (driverActor != null) {
diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala
index eda320fa47..9e6721ec17 100644
--- a/core/src/main/scala/spark/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala
@@ -1,13 +1,10 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
-import akka.pattern.ask
import akka.util.Timeout
import akka.util.duration._
-import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import cc.spray.Directives
-import scala.collection.mutable.ArrayBuffer
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils
@@ -48,32 +45,26 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
path("") {
completeWith {
// Request the current storage status from the Master
- val future = blockManagerMaster ? GetStorageStatus
- future.map { status =>
- // Calculate macro-level statistics
- val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
- val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
- val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
- val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
- .reduceOption(_+_).getOrElse(0L)
- val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
- spark.storage.html.index.
- render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
- }
+ val storageStatusList = sc.getExecutorStorageStatus
+ // Calculate macro-level statistics
+ val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
+ val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
+ val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
+ .reduceOption(_+_).getOrElse(0L)
+ val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
+ spark.storage.html.index.
+ render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
}
} ~
path("rdd") {
parameter("id") { id =>
completeWith {
- val future = blockManagerMaster ? GetStorageStatus
- future.map { status =>
- val prefix = "rdd_" + id.toString
- val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
- val filteredStorageStatusList = StorageUtils.
- filterStorageStatusByPrefix(storageStatusList, prefix)
- val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
- spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
- }
+ val prefix = "rdd_" + id.toString
+ val storageStatusList = sc.getExecutorStorageStatus
+ val filteredStorageStatusList = StorageUtils.
+ filterStorageStatusByPrefix(storageStatusList, prefix)
+ val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
+ spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
}
}
} ~
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
index a10e3a95c6..dec47a9d41 100644
--- a/core/src/main/scala/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -1,6 +1,6 @@
package spark.storage
-import spark.SparkContext
+import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
@@ -22,8 +22,13 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
- numPartitions: Int, memSize: Long, diskSize: Long)
-
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
+ override def toString = {
+ import Utils.memoryBytesToString
+ "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
+ storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
+ }
+}
/* Helper methods for storage-related objects */
private[spark]
@@ -38,8 +43,6 @@ object StorageUtils {
/* Given a list of BlockStatus objets, returns information for each RDD */
def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
- // Find all RDD Blocks (ignore broadcast variables)
- val rddBlocks = infos.filterKeys(_.startsWith("rdd"))
// Group by rddId, ignore the partition name
val groupedRddBlocks = infos.groupBy { case(k, v) =>
@@ -56,10 +59,11 @@ object StorageUtils {
// Find the id of the RDD, e.g. rdd_1 => 1
val rddId = rddKey.split("_").last.toInt
// Get the friendly name for the rdd, if available.
- val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey)
- val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel
-
- RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize)
+ val rdd = sc.persistentRdds(rddId)
+ val rddName = Option(rdd.name).getOrElse(rddKey)
+ val rddStorageLevel = rdd.getStorageLevel
+
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize)
}.toArray
}
@@ -75,4 +79,4 @@ object StorageUtils {
}
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e0fdeffbc4..30aec5a663 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException
* Various utility classes for working with Akka.
*/
private[spark] object AkkaUtils {
+
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
* ActorSystem itself and its port (which is hard to get from Akka).
+ *
+ * Note: the `name` parameter is important, as even if a client sends a message to right
+ * host + port, if the system name is incorrect, Akka will drop the message.
*/
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
@@ -30,6 +34,7 @@ private[spark] object AkkaUtils {
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
+ akka.stdout-loglevel = "ERROR"
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
akka.remote.log-remote-lifecycle-events = on
@@ -41,7 +46,7 @@ private[spark] object AkkaUtils {
akka.actor.default-dispatcher.throughput = %d
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))
- val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)
+ val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
index eaff7ae581..a342d378ff 100644
--- a/core/src/main/scala/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -9,12 +9,12 @@ import spark.Logging
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
- val delaySeconds = MetadataCleaner.getDelaySeconds
- val periodSeconds = math.max(10, delaySeconds / 10)
- val timer = new Timer(name + " cleanup timer", true)
+ private val delaySeconds = MetadataCleaner.getDelaySeconds
+ private val periodSeconds = math.max(10, delaySeconds / 10)
+ private val timer = new Timer(name + " cleanup timer", true)
- val task = new TimerTask {
- def run() {
+ private val task = new TimerTask {
+ override def run() {
try {
cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
logInfo("Ran metadata cleaner for " + name)
diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
new file mode 100644
index 0000000000..301a7e2124
--- /dev/null
+++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
@@ -0,0 +1,40 @@
+@(app: spark.deploy.master.ApplicationInfo)
+
+@spark.common.html.layout(title = "Application Details") {
+
+ <!-- Application Details -->
+ <div class="row">
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>ID:</strong> @app.id</li>
+ <li><strong>Description:</strong> @app.desc.name</li>
+ <li><strong>User:</strong> @app.desc.user</li>
+ <li><strong>Cores:</strong>
+ @app.desc.cores
+ (@app.coresGranted Granted
+ @if(app.desc.cores == Integer.MAX_VALUE) {
+
+ } else {
+ , @app.coresLeft
+ }
+ )
+ </li>
+ <li><strong>Memory per Slave:</strong> @app.desc.memoryPerSlave</li>
+ <li><strong>Submit Date:</strong> @app.submitDate</li>
+ <li><strong>State:</strong> @app.state</li>
+ </ul>
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- Executors -->
+ <div class="row">
+ <div class="span12">
+ <h3> Executor Summary </h3>
+ <br/>
+ @executors_table(app.executors.values.toList)
+ </div>
+ </div>
+
+}
diff --git a/core/src/main/twirl/spark/deploy/master/app_row.scala.html b/core/src/main/twirl/spark/deploy/master/app_row.scala.html
new file mode 100644
index 0000000000..feb306f35c
--- /dev/null
+++ b/core/src/main/twirl/spark/deploy/master/app_row.scala.html
@@ -0,0 +1,20 @@
+@(app: spark.deploy.master.ApplicationInfo)
+
+@import spark.Utils
+@import spark.deploy.WebUI.formatDate
+@import spark.deploy.WebUI.formatDuration
+
+<tr>
+ <td>
+ <a href="app?appId=@(app.id)">@app.id</a>
+ </td>
+ <td>@app.desc.name</td>
+ <td>
+ @app.coresGranted
+ </td>
+ <td>@Utils.memoryMegabytesToString(app.desc.memoryPerSlave)</td>
+ <td>@formatDate(app.submitDate)</td>
+ <td>@app.desc.user</td>
+ <td>@app.state.toString()</td>
+ <td>@formatDuration(app.duration)</td>
+</tr>
diff --git a/core/src/main/twirl/spark/deploy/master/job_table.scala.html b/core/src/main/twirl/spark/deploy/master/app_table.scala.html
index d267d6e85e..f789cee0f1 100644
--- a/core/src/main/twirl/spark/deploy/master/job_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/app_table.scala.html
@@ -1,9 +1,9 @@
-@(jobs: Array[spark.deploy.master.JobInfo])
+@(apps: Array[spark.deploy.master.ApplicationInfo])
<table class="table table-bordered table-striped table-condensed sortable">
<thead>
<tr>
- <th>JobID</th>
+ <th>ID</th>
<th>Description</th>
<th>Cores</th>
<th>Memory per Node</th>
@@ -14,8 +14,8 @@
</tr>
</thead>
<tbody>
- @for(j <- jobs) {
- @job_row(j)
+ @for(j <- apps) {
+ @app_row(j)
}
</tbody>
</table>
diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
index 784d692fc2..d2d80fad48 100644
--- a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
@@ -9,7 +9,7 @@
<td>@executor.memory</td>
<td>@executor.state</td>
<td>
- <a href="@(executor.worker.webUiAddress)/log?jobId=@(executor.job.id)&executorId=@(executor.id)&logType=stdout">stdout</a>
- <a href="@(executor.worker.webUiAddress)/log?jobId=@(executor.job.id)&executorId=@(executor.id)&logType=stderr">stderr</a>
+ <a href="@(executor.worker.webUiAddress)/log?appId=@(executor.application.id)&executorId=@(executor.id)&logType=stdout">stdout</a>
+ <a href="@(executor.worker.webUiAddress)/log?appId=@(executor.application.id)&executorId=@(executor.id)&logType=stderr">stderr</a>
</td>
-</tr> \ No newline at end of file
+</tr>
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index 285645c389..ac51a39a51 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -2,19 +2,19 @@
@import spark.deploy.master._
@import spark.Utils
-@spark.common.html.layout(title = "Spark Master on " + state.uri) {
-
+@spark.common.html.layout(title = "Spark Master on " + state.host) {
+
<!-- Cluster Details -->
<div class="row">
<div class="span12">
<ul class="unstyled">
- <li><strong>URL:</strong> spark://@(state.uri)</li>
+ <li><strong>URL:</strong> @(state.uri)</li>
<li><strong>Workers:</strong> @state.workers.size </li>
<li><strong>Cores:</strong> @{state.workers.map(_.cores).sum} Total,
@{state.workers.map(_.coresUsed).sum} Used</li>
<li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(state.workers.map(_.memory).sum)} Total,
@{Utils.memoryMegabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
- <li><strong>Jobs:</strong> @state.activeJobs.size Running, @state.completedJobs.size Completed </li>
+ <li><strong>Applications:</strong> @state.activeApps.size Running, @state.completedApps.size Completed </li>
</ul>
</div>
</div>
@@ -22,7 +22,7 @@
<!-- Worker Summary -->
<div class="row">
<div class="span12">
- <h3> Cluster Summary </h3>
+ <h3> Workers </h3>
<br/>
@worker_table(state.workers.sortBy(_.id))
</div>
@@ -30,23 +30,23 @@
<hr/>
- <!-- Job Summary (Running) -->
+ <!-- App Summary (Running) -->
<div class="row">
<div class="span12">
- <h3> Running Jobs </h3>
+ <h3> Running Applications </h3>
<br/>
- @job_table(state.activeJobs.sortBy(_.startTime).reverse)
+ @app_table(state.activeApps.sortBy(_.startTime).reverse)
</div>
</div>
<hr/>
- <!-- Job Summary (Completed) -->
+ <!-- App Summary (Completed) -->
<div class="row">
<div class="span12">
- <h3> Completed Jobs </h3>
+ <h3> Completed Applications </h3>
<br/>
- @job_table(state.completedJobs.sortBy(_.endTime).reverse)
+ @app_table(state.completedApps.sortBy(_.endTime).reverse)
</div>
</div>
diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
deleted file mode 100644
index d02a51b214..0000000000
--- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html
+++ /dev/null
@@ -1,40 +0,0 @@
-@(job: spark.deploy.master.JobInfo)
-
-@spark.common.html.layout(title = "Job Details") {
-
- <!-- Job Details -->
- <div class="row">
- <div class="span12">
- <ul class="unstyled">
- <li><strong>ID:</strong> @job.id</li>
- <li><strong>Description:</strong> @job.desc.name</li>
- <li><strong>User:</strong> @job.desc.user</li>
- <li><strong>Cores:</strong>
- @job.desc.cores
- (@job.coresGranted Granted
- @if(job.desc.cores == Integer.MAX_VALUE) {
-
- } else {
- , @job.coresLeft
- }
- )
- </li>
- <li><strong>Memory per Slave:</strong> @job.desc.memoryPerSlave</li>
- <li><strong>Submit Date:</strong> @job.submitDate</li>
- <li><strong>State:</strong> @job.state</li>
- </ul>
- </div>
- </div>
-
- <hr/>
-
- <!-- Executors -->
- <div class="row">
- <div class="span12">
- <h3> Executor Summary </h3>
- <br/>
- @executors_table(job.executors.values.toList)
- </div>
- </div>
-
-}
diff --git a/core/src/main/twirl/spark/deploy/master/job_row.scala.html b/core/src/main/twirl/spark/deploy/master/job_row.scala.html
deleted file mode 100644
index 7c466a6a2c..0000000000
--- a/core/src/main/twirl/spark/deploy/master/job_row.scala.html
+++ /dev/null
@@ -1,20 +0,0 @@
-@(job: spark.deploy.master.JobInfo)
-
-@import spark.Utils
-@import spark.deploy.WebUI.formatDate
-@import spark.deploy.WebUI.formatDuration
-
-<tr>
- <td>
- <a href="job?jobId=@(job.id)">@job.id</a>
- </td>
- <td>@job.desc.name</td>
- <td>
- @job.coresGranted
- </td>
- <td>@Utils.memoryMegabytesToString(job.desc.memoryPerSlave)</td>
- <td>@formatDate(job.submitDate)</td>
- <td>@job.desc.user</td>
- <td>@job.state.toString()</td>
- <td>@formatDuration(job.duration)</td>
-</tr>
diff --git a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
index ea9542461e..dad0a89080 100644
--- a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
@@ -8,13 +8,13 @@
<td>@Utils.memoryMegabytesToString(executor.memory)</td>
<td>
<ul class="unstyled">
- <li><strong>ID:</strong> @executor.jobId</li>
- <li><strong>Name:</strong> @executor.jobDesc.name</li>
- <li><strong>User:</strong> @executor.jobDesc.user</li>
+ <li><strong>ID:</strong> @executor.appId</li>
+ <li><strong>Name:</strong> @executor.appDesc.name</li>
+ <li><strong>User:</strong> @executor.appDesc.user</li>
</ul>
</td>
<td>
- <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stdout">stdout</a>
- <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stderr">stderr</a>
+ <a href="log?appId=@(executor.appId)&executorId=@(executor.execId)&logType=stdout">stdout</a>
+ <a href="log?appId=@(executor.appId)&executorId=@(executor.execId)&logType=stderr">stderr</a>
</td>
</tr>
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index 1d703dae58..c39f769a73 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,8 +1,8 @@
@(worker: spark.deploy.WorkerState)
@import spark.Utils
-@spark.common.html.layout(title = "Spark Worker on " + worker.uri) {
-
+@spark.common.html.layout(title = "Spark Worker on " + worker.host) {
+
<!-- Worker Details -->
<div class="row">
<div class="span12">
@@ -10,12 +10,12 @@
<li><strong>ID:</strong> @worker.workerId</li>
<li><strong>
Master URL:</strong> @worker.masterUrl
- (WebUI at <a href="@worker.masterWebUiUrl">@worker.masterWebUiUrl</a>)
</li>
<li><strong>Cores:</strong> @worker.cores (@worker.coresUsed Used)</li>
<li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(worker.memory)}
(@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</li>
</ul>
+ <p><a href="@worker.masterWebUiUrl">Back to Master</a></p>
</div>
</div>
diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html
index ac7f8c981f..d85addeb17 100644
--- a/core/src/main/twirl/spark/storage/rdd.scala.html
+++ b/core/src/main/twirl/spark/storage/rdd.scala.html
@@ -11,7 +11,11 @@
<strong>Storage Level:</strong>
@(rddInfo.storageLevel.description)
<li>
- <strong>Partitions:</strong>
+ <strong>Cached Partitions:</strong>
+ @(rddInfo.numCachedPartitions)
+ </li>
+ <li>
+ <strong>Total Partitions:</strong>
@(rddInfo.numPartitions)
</li>
<li>
diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html
index af801cf229..a51e64aed0 100644
--- a/core/src/main/twirl/spark/storage/rdd_table.scala.html
+++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html
@@ -6,7 +6,8 @@
<tr>
<th>RDD Name</th>
<th>Storage Level</th>
- <th>Partitions</th>
+ <th>Cached Partitions</th>
+ <th>Fraction Partitions Cached</th>
<th>Size in Memory</th>
<th>Size on Disk</th>
</tr>
@@ -21,7 +22,8 @@
</td>
<td>@(rdd.storageLevel.description)
</td>
- <td>@rdd.numPartitions</td>
+ <td>@rdd.numCachedPartitions</td>
+ <td>@(rdd.numCachedPartitions / rdd.numPartitions.toDouble)</td>
<td>@{Utils.memoryBytesToString(rdd.memSize)}</td>
<td>@{Utils.memoryBytesToString(rdd.diskSize)}</td>
</tr>
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 0b74607fb8..3e5ffa81d6 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -34,7 +34,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+ testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
(i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
@@ -43,14 +43,14 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("ParallelCollection") {
val parCollection = sc.makeRDD(1 to 4, 2)
- val numSplits = parCollection.splits.size
+ val numPartitions = parCollection.partitions.size
parCollection.checkpoint()
assert(parCollection.dependencies === Nil)
val result = parCollection.collect()
assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
assert(parCollection.dependencies != Nil)
- assert(parCollection.splits.length === numSplits)
- assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList)
+ assert(parCollection.partitions.length === numPartitions)
+ assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList)
assert(parCollection.collect() === result)
}
@@ -59,13 +59,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
- val numSplits = blockRDD.splits.size
+ val numPartitions = blockRDD.partitions.size
blockRDD.checkpoint()
val result = blockRDD.collect()
assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
assert(blockRDD.dependencies != Nil)
- assert(blockRDD.splits.length === numSplits)
- assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList)
+ assert(blockRDD.partitions.length === numPartitions)
+ assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
assert(blockRDD.collect() === result)
}
@@ -79,9 +79,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("UnionRDD") {
def otherRDD = sc.makeRDD(1 to 10, 1)
- // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed.
+ // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed.
// Current implementation of UnionRDD has transient reference to parent RDDs,
- // so only the splits will reduce in serialized size, not the RDD.
+ // so only the partitions will reduce in serialized size, not the RDD.
testCheckpointing(_.union(otherRDD), false, true)
testParentCheckpointing(_.union(otherRDD), false, true)
}
@@ -91,21 +91,21 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(new CartesianRDD(sc, _, otherRDD))
// Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
- // so only the RDD will reduce in serialized size, not the splits.
+ // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the partitions.
testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
- // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after
- // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint() // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
- serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
cartesian.count() // do the checkpointing
val splitAfterCheckpoint =
- serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
assert(
(splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
(splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
@@ -114,27 +114,27 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("CoalescedRDD") {
- testCheckpointing(new CoalescedRDD(_, 2))
+ testCheckpointing(_.coalesce(2))
// Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
- // so only the RDD will reduce in serialized size, not the splits.
- testParentCheckpointing(new CoalescedRDD(_, 2), true, false)
+ // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the partitions.
+ testParentCheckpointing(_.coalesce(2), true, false)
- // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after
- // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
- // Note that this test is very specific to the current implementation of CoalescedRDDSplits
+ // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+ // Note that this test is very specific to the current implementation of CoalescedRDDPartitions
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint() // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
- serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
coalesced.count() // do the checkpointing
val splitAfterCheckpoint =
- serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
assert(
splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
- "CoalescedRDDSplit.parents not updated after parent RDD checkpointed"
+ "CoalescedRDDPartition.parents not updated after parent RDD checkpointed"
)
}
@@ -156,8 +156,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
// Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of ZippedRDDSplit has transient references to parent RDDs,
- // so only the RDD will reduce in serialized size, not the splits.
+ // Current implementation of ZippedRDDPartitions has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the partitions.
testParentCheckpointing(
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
}
@@ -165,21 +165,21 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
/**
* Test checkpointing of the final RDD generated by the given operation. By default,
* this method tests whether the size of serialized RDD has reduced after checkpointing or not.
- * It can also test whether the size of serialized RDD splits has reduced after checkpointing or
- * not, but this is not done by default as usually the splits do not refer to any RDD and
+ * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or
+ * not, but this is not done by default as usually the partitions do not refer to any RDD and
* therefore never store the lineage.
*/
def testCheckpointing[U: ClassManifest](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean = true,
- testRDDSplitSize: Boolean = false
+ testRDDPartitionSize: Boolean = false
) {
// Generate the final RDD using given RDD operation
val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.headOption.orNull
val rddType = operatedRDD.getClass.getSimpleName
- val numSplits = operatedRDD.splits.length
+ val numPartitions = operatedRDD.partitions.length
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
@@ -193,11 +193,11 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// Test whether dependencies have been changed from its earlier parent RDD
assert(operatedRDD.dependencies.head.rdd != parentRDD)
- // Test whether the splits have been changed to the new Hadoop splits
- assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList)
+ // Test whether the partitions have been changed to the new Hadoop partitions
+ assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
- // Test whether the number of splits is same as before
- assert(operatedRDD.splits.length === numSplits)
+ // Test whether the number of partitions is same as before
+ assert(operatedRDD.partitions.length === numPartitions)
// Test whether the data in the checkpointed RDD is same as original
assert(operatedRDD.collect() === result)
@@ -215,18 +215,18 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
)
}
- // Test whether serialized size of the splits has reduced. If the splits
- // do not have any non-transient reference to another RDD or another RDD's splits, it
+ // Test whether serialized size of the partitions has reduced. If the partitions
+ // do not have any non-transient reference to another RDD or another RDD's partitions, it
// does not refer to a lineage and therefore may not reduce in size after checkpointing.
- // However, if the original splits before checkpointing do refer to a parent RDD, the splits
+ // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions
// must be forgotten after checkpointing (to remove all reference to parent RDDs) and
- // replaced with the HadoopSplits of the checkpointed RDD.
- if (testRDDSplitSize) {
- logInfo("Size of " + rddType + " splits "
+ // replaced with the HadooPartitions of the checkpointed RDD.
+ if (testRDDPartitionSize) {
+ logInfo("Size of " + rddType + " partitions "
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
assert(
splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
- "Size of " + rddType + " splits did not reduce after checkpointing " +
+ "Size of " + rddType + " partitions did not reduce after checkpointing " +
"[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
)
}
@@ -235,13 +235,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
/**
* Test whether checkpointing of the parent of the generated RDD also
* truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
- * RDDs splits. So even if the parent RDD is checkpointed and its splits changed,
- * this RDD will remember the splits and therefore potentially the whole lineage.
+ * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
+ * this RDD will remember the partitions and therefore potentially the whole lineage.
*/
def testParentCheckpointing[U: ClassManifest](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean,
- testRDDSplitSize: Boolean
+ testRDDPartitionSize: Boolean
) {
// Generate the final RDD using given RDD operation
val baseRDD = generateLongLineageRDD()
@@ -250,9 +250,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
val rddType = operatedRDD.getClass.getSimpleName
val parentRDDType = parentRDD.getClass.getSimpleName
- // Get the splits and dependencies of the parent in case they're lazily computed
+ // Get the partitions and dependencies of the parent in case they're lazily computed
parentRDD.dependencies
- parentRDD.splits
+ parentRDD.partitions
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
@@ -275,16 +275,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
)
}
- // Test whether serialized size of the splits has reduced because of its parent being
- // checkpointed. If the splits do not have any non-transient reference to another RDD
- // or another RDD's splits, it does not refer to a lineage and therefore may not reduce
- // in size after checkpointing. However, if the splits do refer to the *splits* of a parent
- // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's
- // splits must have changed after checkpointing.
- if (testRDDSplitSize) {
+ // Test whether serialized size of the partitions has reduced because of its parent being
+ // checkpointed. If the partitions do not have any non-transient reference to another RDD
+ // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce
+ // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent
+ // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's
+ // partitions must have changed after checkpointing.
+ if (testRDDPartitionSize) {
assert(
splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
- "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType +
+ "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType +
"[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
)
}
@@ -321,12 +321,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
/**
- * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks
+ * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
* upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
(Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
- Utils.serialize(rdd.splits).length)
+ Utils.serialize(rdd.partitions).length)
}
/**
@@ -347,7 +347,7 @@ object CheckpointSuite {
def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
//println("First = " + first + ", second = " + second)
new CoGroupedRDD[K](
- Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]),
+ Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]),
part
).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
}
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
index 342610e1dd..5e84b3a66a 100644
--- a/core/src/test/scala/spark/DriverSuite.scala
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -9,10 +9,11 @@ import org.scalatest.time.SpanSugar._
class DriverSuite extends FunSuite with Timeouts {
test("driver should exit after finishing") {
+ assert(System.getenv("SPARK_HOME") != null)
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
forAll(masters) { (master: String) =>
- failAfter(10 seconds) {
+ failAfter(30 seconds) {
Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master),
new File(System.getenv("SPARK_HOME")))
}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 934e4c2f67..9ffe7c5f99 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -696,4 +696,28 @@ public class JavaAPISuite implements Serializable {
JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
}
+
+ @Test
+ public void mapOnPairRDD() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4));
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
+ JavaPairRDD<Integer, Integer> rdd3 = rdd2.map(
+ new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception {
+ return new Tuple2<Integer, Integer>(in._2(), in._1());
+ }
+ });
+ Assert.assertEquals(Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(0, 2),
+ new Tuple2<Integer, Integer>(1, 3),
+ new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
+
+ }
}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index f4e7ec39fe..dd19442dcb 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -79,8 +79,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("remote fetch") {
try {
System.clearProperty("spark.driver.host") // In case some previous test had set it
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("test", "localhost", 0)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
System.setProperty("spark.driver.port", boundPort.toString)
val masterTracker = new MapOutputTracker(actorSystem, true)
val slaveTracker = new MapOutputTracker(actorSystem, false)
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index ed03e65153..9739ba869b 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -12,9 +12,10 @@ class RDDSuite extends FunSuite with LocalSparkContext {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
- assert(dups.distinct.count === 4)
- assert(dups.distinct().collect === dups.distinct.collect)
- assert(dups.distinct(2).collect === dups.distinct.collect)
+ assert(dups.distinct().count() === 4)
+ assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
+ assert(dups.distinct.collect === dups.distinct().collect)
+ assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -31,6 +32,15 @@ class RDDSuite extends FunSuite with LocalSparkContext {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+
+ val partitionSumsWithIndex = nums.mapPartitionsWithIndex {
+ case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
+ }
+ assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7)))
+
+ intercept[UnsupportedOperationException] {
+ nums.filter(_ > 5).reduce(_ + _)
+ }
}
test("SparkContext.union") {
@@ -92,12 +102,12 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("caching with failures") {
sc = new SparkContext("local", "test")
- val onlySplit = new Split { override def index: Int = 0 }
+ val onlySplit = new Partition { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
- override def getSplits: Array[Split] = Array(onlySplit)
+ override def getPartitions: Array[Partition] = Array(onlySplit)
override val getDependencies = List[Dependency[_]]()
- override def compute(split: Split, context: TaskContext): Iterator[Int] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
if (shouldFail) {
throw new Exception("injected failure")
} else {
@@ -117,7 +127,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
- val coalesced1 = new CoalescedRDD(data, 2)
+ val coalesced1 = data.coalesce(2)
assert(coalesced1.collect().toList === (1 to 10).toList)
assert(coalesced1.glom().collect().map(_.toList).toList ===
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
@@ -128,19 +138,19 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
List(5, 6, 7, 8, 9))
- val coalesced2 = new CoalescedRDD(data, 3)
+ val coalesced2 = data.coalesce(3)
assert(coalesced2.collect().toList === (1 to 10).toList)
assert(coalesced2.glom().collect().map(_.toList).toList ===
List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10)))
- val coalesced3 = new CoalescedRDD(data, 10)
+ val coalesced3 = data.coalesce(10)
assert(coalesced3.collect().toList === (1 to 10).toList)
assert(coalesced3.glom().collect().map(_.toList).toList ===
(1 to 10).map(x => List(x)).toList)
// If we try to coalesce into more partitions than the original RDD, it should just
// keep the original number of partitions.
- val coalesced4 = new CoalescedRDD(data, 20)
+ val coalesced4 = data.coalesce(20)
assert(coalesced4.collect().toList === (1 to 10).toList)
assert(coalesced4.glom().collect().map(_.toList).toList ===
(1 to 10).map(x => List(x)).toList)
@@ -163,8 +173,8 @@ class RDDSuite extends FunSuite with LocalSparkContext {
val data = sc.parallelize(1 to 10, 10)
// Note that split number starts from 0, so > 8 means only 10th partition left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
- assert(prunedRdd.splits.size === 1)
- val prunedData = prunedRdd.collect
+ assert(prunedRdd.partitions.size === 1)
+ val prunedData = prunedRdd.collect()
assert(prunedData.size === 1)
assert(prunedData(0) === 10)
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..92c3f67416 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,6 +1,7 @@
package spark
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
@@ -98,6 +99,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
+
+ test("reduceByKey with partitioner") {
+ sc = new SparkContext("local", "test")
+ val p = new Partitioner() {
+ def numPartitions = 2
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+ val sums = pairs.reduceByKey(_+_)
+ assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+ assert(sums.partitioner === Some(p))
+ // count the dependencies to make sure there is only 1 ShuffledRDD
+ val deps = new HashSet[RDD[_]]()
+ def visit(r: RDD[_]) {
+ for (dep <- r.dependencies) {
+ deps += dep.rdd
+ visit(dep.rdd)
+ }
+ }
+ visit(sums)
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ }
test("join") {
sc = new SparkContext("local", "test")
@@ -199,7 +222,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
sc = new SparkContext("local", "test")
val emptyDir = Files.createTempDir()
val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.splits.size == 0)
+ assert(file.partitions.size == 0)
assert(file.collect().toList === Nil)
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index edb8c839fc..495f957e53 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -19,7 +19,7 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey()
- assert(sorted.splits.size === 2)
+ assert(sorted.partitions.size === 2)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
@@ -29,17 +29,17 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey(true, 1)
- assert(sorted.splits.size === 1)
+ assert(sorted.partitions.size === 1)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
- test("large array with many splits") {
+ test("large array with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey(true, 20)
- assert(sorted.splits.size === 20)
+ assert(sorted.partitions.size === 20)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
@@ -59,7 +59,7 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
- test("sort descending with many splits") {
+ test("sort descending with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
diff --git a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala
index 450c69bd58..d27a2538e4 100644
--- a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import scala.collection.immutable.NumericRange
@@ -11,7 +11,7 @@ import org.scalacheck.Prop._
class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("one element per slice") {
val data = Array(1, 2, 3)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1")
assert(slices(1).mkString(",") === "2")
@@ -20,14 +20,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("one slice") {
val data = Array(1, 2, 3)
- val slices = ParallelCollection.slice(data, 1)
+ val slices = ParallelCollectionRDD.slice(data, 1)
assert(slices.size === 1)
assert(slices(0).mkString(",") === "1,2,3")
}
test("equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -36,7 +36,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("non-equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -45,7 +45,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("splitting exclusive range") {
val data = 0 until 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 65).mkString(","))
@@ -54,7 +54,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("splitting inclusive range") {
val data = 0 to 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 66).mkString(","))
@@ -63,24 +63,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("empty data") {
val data = new Array[Int](0)
- val slices = ParallelCollection.slice(data, 5)
+ val slices = ParallelCollectionRDD.slice(data, 5)
assert(slices.size === 5)
for (slice <- slices) assert(slice.size === 0)
}
test("zero slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelCollection.slice(data, 0) }
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) }
}
test("negative number of slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelCollection.slice(data, -5) }
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) }
}
test("exclusive ranges sliced into ranges") {
val data = 1 until 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -88,7 +88,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges sliced into ranges") {
val data = 1 to 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -97,7 +97,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("large ranges don't overflow") {
val N = 100 * 1000 * 1000
val data = 0 until N
- val slices = ParallelCollection.slice(data, 40)
+ val slices = ParallelCollectionRDD.slice(data, 40)
assert(slices.size === 40)
for (i <- 0 until 40) {
assert(slices(i).isInstanceOf[Range])
@@ -117,7 +117,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
(tuple: (List[Int], Int)) =>
val d = tuple._1
val n = tuple._2
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
@@ -134,7 +134,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
} yield (a until b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -152,7 +152,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
} yield (a to b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -163,7 +163,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("exclusive ranges of longs") {
val data = 1L until 100L
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -171,7 +171,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges of longs") {
val data = 1L to 100L
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -179,7 +179,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("exclusive ranges of doubles") {
val data = 1.0 until 100.0 by 1.0
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -187,7 +187,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges of doubles") {
val data = 1.0 to 100.0 by 1.0
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..8de490eb86
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,663 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.TimeLimitedTests
+import org.scalatest.mock.EasyMockSugar
+import org.scalatest.time.{Span, Seconds}
+
+import org.easymock.EasyMock._
+import org.easymock.Capture
+import org.easymock.EasyMock
+import org.easymock.{IAnswer, IArgumentMatcher}
+
+import akka.actor.ActorSystem
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Partition
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+
+ // impose a time limit on this test in case we don't let the job finish, in which case
+ // JobWaiter#getResult will hang.
+ override val timeLimit = Span(5, Seconds)
+
+ val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
+ var scheduler: DAGScheduler = null
+ val taskScheduler = mock[TaskScheduler]
+ val blockManagerMaster = mock[BlockManagerMaster]
+ var mapOutputTracker: MapOutputTracker = null
+ var schedulerThread: Thread = null
+ var schedulerException: Throwable = null
+
+ /**
+ * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
+ * We cache these so we do not create duplicate matchers for the same RDD.
+ * This allows us to easily setup a sequence of expectations for task sets for
+ * that RDD.
+ */
+ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+ /**
+ * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which
+ * will only submit one job) from needing to explicitly track it.
+ */
+ var lastJobWaiter: JobWaiter[Int] = null
+
+ /**
+ * Array into which we are accumulating the results from the last job asynchronously.
+ */
+ var lastJobResult: Array[Int] = null
+
+ /**
+ * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
+ * and whenExecuting {...} */
+ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+
+ /**
+ * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
+ * to be reset after each time their expectations are set, and we tend to check mock object
+ * calls over a single call to DAGScheduler.
+ *
+ * We also set a default expectation here that blockManagerMaster.getLocations can be called
+ * and will return values from cacheLocations.
+ */
+ def resetExpecting(f: => Unit) {
+ reset(taskScheduler)
+ reset(blockManagerMaster)
+ expecting {
+ expectGetLocations()
+ f
+ }
+ }
+
+ before {
+ taskSetMatchers.clear()
+ cacheLocations.clear()
+ val actorSystem = ActorSystem("test")
+ mapOutputTracker = new MapOutputTracker(actorSystem, true)
+ resetExpecting {
+ taskScheduler.setListener(anyObject())
+ }
+ whenExecuting {
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+ }
+ }
+
+ after {
+ assert(scheduler.processEvent(StopDAGScheduler))
+ resetExpecting {
+ taskScheduler.stop()
+ }
+ whenExecuting {
+ scheduler.stop()
+ }
+ sc.stop()
+ System.clearProperty("spark.master.port")
+ }
+
+ def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345)
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ def makeRdd(
+ numPartitions: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxPartition = numPartitions - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
+
+ /**
+ * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+ * is from a particular RDD.
+ */
+ def taskSetForRdd(rdd: MyRDD): TaskSet = {
+ val matcher = taskSetMatchers.getOrElseUpdate(rdd,
+ new IArgumentMatcher {
+ override def matches(actual: Any): Boolean = {
+ val taskSet = actual.asInstanceOf[TaskSet]
+ taskSet.tasks(0) match {
+ case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
+ case smt: ShuffleMapTask => smt.rdd.id == rdd.id
+ case _ => false
+ }
+ }
+ override def appendTo(buf: StringBuffer) {
+ buf.append("taskSetForRdd(" + rdd + ")")
+ }
+ })
+ EasyMock.reportMatcher(matcher)
+ return null
+ }
+
+ /**
+ * Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from
+ * cacheLocations.
+ */
+ def expectGetLocations(): Unit = {
+ EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
+ andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
+ override def answer(): Seq[Seq[BlockManagerId]] = {
+ val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
+ return blocks.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ if (cacheLocations.contains(key)) {
+ cacheLocations(key)
+ } else {
+ Seq[BlockManagerId]()
+ }
+ } else {
+ Seq[BlockManagerId]()
+ }
+ }.toSeq
+ }
+ }).anyTimes()
+ }
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+ * called from a resetExpecting { ... } block.
+ *
+ * Returns a easymock Capture that will contain the task set after the stage is submitted.
+ * Most tests should use interceptStage() instead of this directly.
+ */
+ def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+ val taskSetCapture = new Capture[TaskSet]
+ taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+ return taskSetCapture
+ }
+
+ /**
+ * Expect the supplied code snippet to submit a stage for the specified RDD.
+ * Return the resulting TaskSet. First marks all the tasks are belonging to the
+ * current MapOutputTracker generation.
+ */
+ def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+ var capture: Capture[TaskSet] = null
+ resetExpecting {
+ capture = expectStage(rdd)
+ }
+ whenExecuting {
+ f
+ }
+ val taskSet = capture.getValue
+ for (task <- taskSet.tasks) {
+ task.generation = mapOutputTracker.getGeneration
+ }
+ return taskSet
+ }
+
+ /**
+ * Send the given CompletionEvent messages for the tasks in the TaskSet.
+ */
+ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+ }
+ }
+ }
+
+ /**
+ * Assert that the supplied TaskSet has exactly the given preferredLocations.
+ */
+ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+ assert(locations.size === taskSet.tasks.size)
+ for ((expectLocs, taskLocs) <-
+ taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+ assert(expectLocs === taskLocs)
+ }
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+ it.next._1.asInstanceOf[Int]
+
+
+ /**
+ * Start a job to compute the given RDD. Returns the JobWaiter that will
+ * collect the result of the job via callbacks from DAGScheduler.
+ */
+ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
+ val resultArray = new Array[Int](rdd.partitions.size)
+ val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
+ rdd,
+ jobComputeFunc,
+ (0 to (rdd.partitions.size - 1)),
+ "test-site",
+ allowLocal,
+ (i: Int, value: Int) => resultArray(i) = value
+ )
+ lastJobWaiter = waiter
+ lastJobResult = resultArray
+ runEvent(toSubmit)
+ return (waiter, resultArray)
+ }
+
+ /**
+ * Assert that a job we started has failed.
+ */
+ def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
+ waiter.awaitResult() match {
+ case JobSucceeded => fail()
+ case JobFailed(_) => return
+ }
+ }
+
+ /**
+ * Assert that a job we started has succeeded and has the given result.
+ */
+ def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
+ result: Array[Int] = lastJobResult) {
+ waiter.awaitResult match {
+ case JobSucceeded =>
+ assert(expected === result)
+ case JobFailed(_) =>
+ fail()
+ }
+ }
+
+ def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ def accumulateResult(partition: Int, value: Int) {
+ numResults += 1
+ }
+ scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getPartitions = Array( new Partition { override def index = 0 } )
+ override def getPreferredLocations(split: Partition) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ submitRdd(rdd, true)
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("trivial job failure") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ runEvent(TaskSetFailed(taskSet, "test failure"))
+ expectJobException()
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(secondStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(secondStage, List(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
+ ))
+ }
+ val thirdStage = interceptStage(shuffleMapRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val fourthStage = interceptStage(reduceRdd) {
+ respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(fourthStage, List( (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val oldGeneration = mapOutputTracker.getGeneration
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ val newGeneration = mapOutputTracker.getGeneration
+ assert(newGeneration > oldGeneration)
+ val noAccum = Map[Long, Any]()
+ // We rely on the event queue being ordered and increasing the generation number by 1
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ taskSet.tasks(1).generation = newGeneration
+ val secondStage = interceptStage(reduceRdd) {
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+ // rather than marking it is as failed and waiting.
+ val secondStage = interceptStage(shuffleMapRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ val thirdStage = interceptStage(reduceRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ respondToTaskSet(thirdStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ val secondStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val thirdStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(thirdStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List(
+ (Success, makeMapStatus("hostA", 2))
+ ))
+ }
+ val finalStage = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostA", 1))
+ ))
+ }
+ respondToTaskSet(finalStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostD", 1))
+ ))
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle but fails") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+ intercept[FetchFailedException]{
+ mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+ }
+
+ // Simulate the shuffle input data failing to be cached.
+ cacheLocations.remove(shuffleTwoRdd.id -> 0)
+ respondToTaskSet(recomputeTwoCached, List(
+ (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+ ))
+
+ // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+ // everything.
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+ val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
index a5db7103f5..647bcaf860 100644
--- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -5,7 +5,7 @@ import org.scalatest.BeforeAndAfter
import spark.TaskContext
import spark.RDD
import spark.SparkContext
-import spark.Split
+import spark.Partition
import spark.LocalSparkContext
class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
@@ -14,8 +14,8 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
var completed = false
sc = new SparkContext("local", "test")
val rdd = new RDD[String](sc, List()) {
- override def getSplits = Array[Split](StubSplit(0))
- override def compute(split: Split, context: TaskContext) = {
+ override def getPartitions = Array[Partition](StubPartition(0))
+ override def compute(split: Partition, context: TaskContext) = {
context.addOnCompleteCallback(() => completed = true)
sys.error("failed")
}
@@ -28,5 +28,5 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
assert(completed === true)
}
- case class StubSplit(val index: Int) extends Split
-} \ No newline at end of file
+ case class StubPartition(val index: Int) extends Partition
+}