author     Reynold Xin <rxin@cs.berkeley.edu>   2012-12-13 16:48:34 -0800
committer  Reynold Xin <rxin@cs.berkeley.edu>   2012-12-13 16:48:34 -0800
commit     dc7d7fc28656948d94000e6532f1437d597cd4db (patch)
tree       3cddc75943507183242c3273b1e7804219ee76df
parent     1b7a0451ed7df78838ca7ea09dfa5ba0e236acfe (diff)
parent     dfcbb2ad9720e00ad0f86c72fce54c6c84b01260 (diff)
Merge branch 'master' of github.com:mesos/spark into spark-633
-rw-r--r--  bagel/pom.xml                                                  |  3
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala                   | 29
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala               | 15
-rw-r--r--  core/src/main/scala/spark/ParallelCollection.scala             | 17
-rw-r--r--  core/src/main/scala/spark/RDD.scala                            |  8
-rw-r--r--  core/src/main/scala/spark/TaskContext.scala                    | 19
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDDLike.scala           | 25
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala                   | 25
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala               | 17
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala               | 30
-rw-r--r--  core/src/main/scala/spark/rdd/CoalescedRDD.scala               |  9
-rw-r--r--  core/src/main/scala/spark/rdd/FilteredRDD.scala                |  7
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedRDD.scala              | 10
-rw-r--r--  core/src/main/scala/spark/rdd/GlommedRDD.scala                 |  8
-rw-r--r--  core/src/main/scala/spark/rdd/HadoopRDD.scala                  | 22
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsRDD.scala           |  9
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala  |  7
-rw-r--r--  core/src/main/scala/spark/rdd/MappedRDD.scala                  |  8
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala               | 34
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala                   | 11
-rw-r--r--  core/src/main/scala/spark/rdd/SampledRDD.scala                 | 15
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala                |  9
-rw-r--r--  core/src/main/scala/spark/rdd/UnionRDD.scala                   | 22
-rw-r--r--  core/src/main/scala/spark/rdd/ZippedRDD.scala                  | 19
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala         | 17
-rw-r--r--  core/src/main/scala/spark/scheduler/ResultTask.scala           |  6
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala       | 17
-rwxr-xr-x  ec2/spark_ec2.py                                               | 78
-rw-r--r--  examples/pom.xml                                               |  3

29 files changed, 268 insertions(+), 231 deletions(-)
diff --git a/bagel/pom.xml b/bagel/pom.xml
index b462801589..a8256a6e8b 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -55,6 +55,7 @@
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
</dependency>
</dependencies>
<build>
@@ -81,10 +82,12 @@
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
</dependency>
</dependencies>
<build>
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index c5db6ce63a..3d79078733 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -1,5 +1,9 @@
package spark
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
@@ -8,10 +12,6 @@ import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
import spark.storage.BlockManager
import spark.storage.StorageLevel
@@ -41,7 +41,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
-
+
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
slaveCapacity.put(host, size)
@@ -92,14 +92,14 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
-
+
// Tracker actor on the master, or remote reference to it on workers
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "CacheTracker"
val timeout = 10.seconds
-
+
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
logInfo("Registered CacheTrackerActor actor")
@@ -132,7 +132,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
throw new SparkException("Error reply received from CacheTracker")
}
}
-
+
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
registeredRddIds.synchronized {
@@ -143,7 +143,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
}
}
}
-
+
// For BlockManager.scala only
def cacheLost(host: String) {
communicate(MemoryCacheLost(host))
@@ -155,19 +155,20 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
def getCacheStatus(): Seq[(String, Long, Long)] = {
askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
}
-
+
// For BlockManager.scala only
def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
}
-
+
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
}
-
+
// Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
+ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ : Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
@@ -209,7 +210,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
val elements = new ArrayBuffer[Any]
- elements ++= rdd.compute(split)
+ elements ++= rdd.compute(split, context)
try {
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index e5bb639cfd..08ae06e865 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -35,11 +35,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
with Serializable {
/**
- * Generic function to combine the elements for each key using a custom set of aggregation
+ * Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C
* Note that V and C can be different -- for example, one might group an RDD of type
* (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions:
- *
+ *
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
@@ -118,7 +118,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
- /**
+ /**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
@@ -224,7 +224,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
- /**
+ /**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the default
* parallelism level.
*/
@@ -628,7 +628,8 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)]
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner
- override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))}
+ override def compute(split: Split, taskContext: TaskContext) =
+ prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))}
}
private[spark]
@@ -639,8 +640,8 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]
override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner
- override def compute(split: Split) = {
- prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) }
+ override def compute(split: Split, taskContext: TaskContext) = {
+ prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index 9b57ae3b4f..a27f766e31 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -8,8 +8,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
val slice: Int,
values: Seq[T])
extends Split with Serializable {
-
- def iterator(): Iterator[T] = values.iterator
+
+ def iterator: Iterator[T] = values.iterator
override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
@@ -22,7 +22,7 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
- sc: SparkContext,
+ sc: SparkContext,
@transient data: Seq[T],
numSlices: Int)
extends RDD[T](sc) {
@@ -38,17 +38,18 @@ private[spark] class ParallelCollection[T: ClassManifest](
override def splits = splits_.asInstanceOf[Array[Split]]
- override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator
-
+ override def compute(s: Split, taskContext: TaskContext) =
+ s.asInstanceOf[ParallelCollectionSplit[T]].iterator
+
override def preferredLocations(s: Split): Seq[String] = Nil
-
+
override val dependencies: List[Dependency[_]] = Nil
}
private object ParallelCollection {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
- * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
+ * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
* it efficient to run Spark over RDDs representing large sets of numbers.
*/
def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
@@ -58,7 +59,7 @@ private object ParallelCollection {
seq match {
case r: Range.Inclusive => {
val sign = if (r.step < 0) {
- -1
+ -1
} else {
1
}
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 6270e018b3..bb4c13c494 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -81,7 +81,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def splits: Array[Split]
/** Function for computing a given partition. */
- def compute(split: Split): Iterator[T]
+ def compute(split: Split, context: TaskContext): Iterator[T]
/** How this RDD depends on any parent RDDs. */
@transient val dependencies: List[Dependency[_]]
@@ -155,11 +155,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
- final def iterator(split: Split): Iterator[T] = {
+ final def iterator(split: Split, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
+ SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
} else {
- compute(split)
+ compute(split, context)
}
}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index c14377d17b..d2746b26b3 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -1,3 +1,20 @@
package spark
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable
+import scala.collection.mutable.ArrayBuffer
+
+
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
+
+ @transient
+ val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+
+ // Add a callback function to be executed on task completion. An example use
+ // is for HadoopRDD to register a callback to close the input stream.
+ def addOnCompleteCallback(f: () => Unit) {
+ onCompleteCallbacks += f
+ }
+
+ def executeOnCompleteCallbacks() {
+ onCompleteCallbacks.foreach{_()}
+ }
+}
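
For reference, a minimal sketch of how the new TaskContext callback hooks are meant to be driven, assuming the spark-core classes from this snapshot are on the classpath; the demo object, the ids, and the input path are invented for illustration and are not part of this patch:

```scala
import java.io.FileInputStream

import spark.TaskContext

object TaskContextCallbackDemo {
  def main(args: Array[String]) {
    // Hypothetical ids: stageId = 0, splitId = 0, attemptId = 0L.
    val context = new TaskContext(0, 0, 0L)

    // Register a cleanup action, the same way HadoopRDD registers
    // `() => reader.close()` inside compute() later in this patch.
    val in = new FileInputStream("/tmp/example-input")  // hypothetical path
    context.addOnCompleteCallback(() => in.close())

    // ... a task body would consume rdd.iterator(split, context) here ...

    // Run every registered callback once the task has finished, as
    // ResultTask and ShuffleMapTask do after computing their results.
    context.executeOnCompleteCallbacks()
  }
}
```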
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 482eb9281a..81d3a94466 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -1,16 +1,15 @@
package spark.api.java
-import spark.{SparkContext, Split, RDD}
+import java.util.{List => JList}
+import scala.Tuple2
+import scala.collection.JavaConversions._
+
+import spark.{SparkContext, Split, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import spark.partial.{PartialResult, BoundedDouble}
import spark.storage.StorageLevel
-import java.util.{List => JList}
-
-import scala.collection.JavaConversions._
-import java.{util, lang}
-import scala.Tuple2
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
@@ -24,7 +23,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/** The [[spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
-
+
/** A unique ID for this RDD (within its SparkContext). */
def id: Int = rdd.id
@@ -36,7 +35,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
- def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split))
+ def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] =
+ asJavaIterator(rdd.iterator(split, taskContext))
// Transformations (return a new RDD)
@@ -99,7 +99,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
}
-
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
@@ -183,7 +182,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
// Actions (launch a job to return a value to the user program)
-
+
/**
* Applies a function f to all elements of this RDD.
*/
@@ -200,7 +199,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
val arr: java.util.Collection[T] = rdd.collect().toSeq
new java.util.ArrayList(arr)
}
-
+
/**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
@@ -208,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
- * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
+ * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
* modify t2.
*/
@@ -251,7 +250,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): java.util.Map[T, java.lang.Long] =
- mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2)))))
+ mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
/**
* (Experimental) Approximate version of countByValue().
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index cb73976aed..f98528a183 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -2,11 +2,8 @@ package spark.rdd
import scala.collection.mutable.HashMap
-import spark.Dependency
-import spark.RDD
-import spark.SparkContext
-import spark.SparkEnv
-import spark.Split
+import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext}
+
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
@@ -19,29 +16,29 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
@transient
val splits_ = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
- }).toArray
-
- @transient
+ }).toArray
+
+ @transient
lazy val locations_ = {
- val blockManager = SparkEnv.get.blockManager
+ val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
- val locations = blockManager.getLocations(blockIds)
+ val locations = blockManager.getLocations(blockIds)
HashMap(blockIds.zip(locations):_*)
}
override def splits = splits_
- override def compute(split: Split): Iterator[T] = {
- val blockManager = SparkEnv.get.blockManager
+ override def compute(split: Split, context: TaskContext): Iterator[T] = {
+ val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDSplit].blockId
blockManager.get(blockId) match {
case Some(block) => block.asInstanceOf[Iterator[T]]
- case None =>
+ case None =>
throw new Exception("Could not compute split, block " + blockId + " not found")
}
}
- override def preferredLocations(split: Split) =
+ override def preferredLocations(split: Split) =
locations_(split.asInstanceOf[BlockRDDSplit].blockId)
override val dependencies: List[Dependency[_]] = Nil
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 7c354b6b2e..4a7e5f3d06 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,9 +1,7 @@
package spark.rdd
-import spark.NarrowDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext}
+
private[spark]
class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
@@ -17,9 +15,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
rdd2: RDD[U])
extends RDD[Pair[T, U]](sc)
with Serializable {
-
+
val numSplitsInRdd2 = rdd2.splits.size
-
+
@transient
val splits_ = {
// create the cross product split
@@ -38,11 +36,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
- override def compute(split: Split) = {
+ override def compute(split: Split, context: TaskContext) = {
val currSplit = split.asInstanceOf[CartesianSplit]
- for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y)
+ for (x <- rdd1.iterator(currSplit.s1, context);
+ y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
-
+
override val dependencies = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 50bec9e63b..de0d9fad88 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -3,21 +3,15 @@ package spark.rdd
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.Aggregator
-import spark.Dependency
-import spark.Logging
-import spark.OneToOneDependency
-import spark.Partitioner
-import spark.RDD
-import spark.ShuffleDependency
-import spark.SparkEnv
-import spark.Split
+import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
+import spark.{Dependency, OneToOneDependency, ShuffleDependency}
+
private[spark] sealed trait CoGroupSplitDep extends Serializable
private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
-private[spark]
+private[spark]
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
@@ -32,9 +26,9 @@ private[spark] class CoGroupAggregator
class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
-
+
val aggr = new CoGroupAggregator
-
+
@transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
@@ -50,7 +44,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
deps.toList
}
-
+
@transient
val splits_ : Array[Split] = {
val firstRdd = rdds.head
@@ -69,12 +63,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
override def splits = splits_
-
+
override val partitioner = Some(part)
-
+
override def preferredLocations(s: Split) = Nil
-
- override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
+
+ override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
@@ -84,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
// Read them from the parent
- for ((k, v) <- rdd.iterator(itsSplit)) {
+ for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
}
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 0967f4f5df..1affe0e0ef 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,8 +1,7 @@
package spark.rdd
-import spark.NarrowDependency
-import spark.RDD
-import spark.Split
+import spark.{NarrowDependency, RDD, Split, TaskContext}
+
private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
@@ -32,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
override def splits = splits_
- override def compute(split: Split): Iterator[T] = {
+ override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
- parentSplit => prev.iterator(parentSplit)
+ parentSplit => prev.iterator(parentSplit, context)
}
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index dfe9dc73f3..b148da28de 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -1,12 +1,11 @@
package spark.rdd
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
private[spark]
class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).filter(f)
+ override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 3534dc8057..785662b2da 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,16 +1,16 @@
package spark.rdd
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => TraversableOnce[U])
extends RDD[U](prev.context) {
-
+
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).flatMap(f)
+
+ override def compute(split: Split, context: TaskContext) =
+ prev.iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index e30564f2da..fac8ffb4cb 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,12 +1,12 @@
package spark.rdd
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
private[spark]
class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
+ override def compute(split: Split, context: TaskContext) =
+ Array(prev.iterator(split, context).toArray).iterator
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index bf29a1f075..ab163f569b 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -15,19 +15,16 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import spark.Dependency
-import spark.RDD
-import spark.SerializableWritable
-import spark.SparkContext
-import spark.Split
+import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
-/**
+
+/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
extends Split
with Serializable {
-
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -47,10 +44,10 @@ class HadoopRDD[K, V](
valueClass: Class[V],
minSplits: Int)
extends RDD[(K, V)](sc) {
-
+
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
-
+
@transient
val splits_ : Array[Split] = {
val inputFormat = createInputFormat(conf)
@@ -69,7 +66,7 @@ class HadoopRDD[K, V](
override def splits = splits_
- override def compute(theSplit: Split) = new Iterator[(K, V)] {
+ override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopSplit]
var reader: RecordReader[K, V] = null
@@ -77,6 +74,9 @@ class HadoopRDD[K, V](
val fmt = createInputFormat(conf)
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => reader.close())
+
val key: K = reader.createKey()
val value: V = reader.createValue()
var gotNext = false
@@ -115,6 +115,6 @@ class HadoopRDD[K, V](
val hadoopSplit = split.asInstanceOf[HadoopSplit]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
-
+
override val dependencies: List[Dependency[_]] = Nil
}
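
The same pattern generalizes to any RDD that holds a per-partition resource open during iteration. A hypothetical sketch written against the compute(split, context) signature introduced above — LoggingRDD and its output path are invented for illustration and do not appear in this patch:

```scala
package spark.rdd

import java.io.PrintWriter

import spark.{OneToOneDependency, RDD, Split, TaskContext}

/**
 * Hypothetical RDD that writes every element of its parent to a
 * per-partition log file while passing the elements through unchanged.
 */
private[spark]
class LoggingRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev.context) {
  override def splits = prev.splits
  override val dependencies = List(new OneToOneDependency(prev))

  override def compute(split: Split, context: TaskContext): Iterator[T] = {
    // Hypothetical log location, keyed by RDD id and split index.
    val writer = new PrintWriter("/tmp/rdd-%d-split-%d.log".format(id, split.index))
    // Close the writer when the task completes, the same way HadoopRDD
    // now closes its RecordReader via an on-completion callback.
    context.addOnCompleteCallback(() => writer.close())
    prev.iterator(split, context).map { elem =>
      writer.println(elem)
      elem
    }
  }
}
```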
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index a904ef62c3..c764505345 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,8 +1,7 @@
package spark.rdd
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
private[spark]
class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
@@ -12,8 +11,8 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
extends RDD[U](prev.context) {
override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
+
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = f(prev.iterator(split))
+ override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context))
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 14e390c43b..3d9888bd34 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -1,8 +1,6 @@
package spark.rdd
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
/**
* A variant of the MapPartitionsRDD that passes the split index into the
@@ -19,5 +17,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = f(split.index, prev.iterator(split))
+ override def compute(split: Split, context: TaskContext) =
+ f(split.index, prev.iterator(split, context))
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 59bedad8ef..70fa8f4497 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,16 +1,14 @@
package spark.rdd
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => U)
extends RDD[U](prev.context) {
-
+
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).map(f)
+ override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f)
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 7a1a0fb87d..197ed5ea17 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -1,22 +1,19 @@
package spark.rdd
+import java.text.SimpleDateFormat
+import java.util.Date
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import java.util.Date
-import java.text.SimpleDateFormat
+import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
-import spark.Dependency
-import spark.RDD
-import spark.SerializableWritable
-import spark.SparkContext
-import spark.Split
-private[spark]
+private[spark]
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
extends Split {
-
+
val serializableHadoopSplit = new SerializableWritable(rawSplit)
override def hashCode(): Int = (41 * (41 + rddId) + index)
@@ -29,7 +26,7 @@ class NewHadoopRDD[K, V](
@transient conf: Configuration)
extends RDD[(K, V)](sc)
with HadoopMapReduceUtil {
-
+
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
// private val serializableConf = new SerializableWritable(conf)
@@ -56,15 +53,19 @@ class NewHadoopRDD[K, V](
override def splits = splits_
- override def compute(theSplit: Split) = new Iterator[(K, V)] {
+ override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
- val context = newTaskAttemptContext(conf, attemptId)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
- val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
- reader.initialize(split.serializableHadoopSplit.value, context)
-
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => reader.close())
+
var havePair = false
var finished = false
@@ -72,9 +73,6 @@ class NewHadoopRDD[K, V](
if (!finished && !havePair) {
finished = !reader.nextKeyValue
havePair = !finished
- if (finished) {
- reader.close()
- }
}
!finished
}
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 98ea0c92d6..336e193217 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -8,10 +8,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
-import spark.OneToOneDependency
-import spark.RDD
-import spark.SparkEnv
-import spark.Split
+import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
/**
@@ -32,12 +29,12 @@ class PipedRDD[T: ClassManifest](
override val dependencies = List(new OneToOneDependency(parent))
- override def compute(split: Split): Iterator[String] = {
+ override def compute(split: Split, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
-
+
val proc = pb.start()
val env = SparkEnv.get
@@ -55,7 +52,7 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
- for (elem <- parent.iterator(split)) {
+ for (elem <- parent.iterator(split, context)) {
out.println(elem)
}
out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 87a5268f27..6e4797aabb 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -4,9 +4,8 @@ import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
-import spark.RDD
-import spark.OneToOneDependency
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
@@ -15,7 +14,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
class SampledRDD[T: ClassManifest](
prev: RDD[T],
- withReplacement: Boolean,
+ withReplacement: Boolean,
frac: Double,
seed: Int)
extends RDD[T](prev.context) {
@@ -29,17 +28,17 @@ class SampledRDD[T: ClassManifest](
override def splits = splits_.asInstanceOf[Array[Split]]
override val dependencies = List(new OneToOneDependency(prev))
-
+
override def preferredLocations(split: Split) =
prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
- override def compute(splitIn: Split) = {
+ override def compute(splitIn: Split, context: TaskContext) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
val poisson = new Poisson(frac, new DRand(split.seed))
- prev.iterator(split.prev).flatMap { element =>
+ prev.iterator(split.prev, context).flatMap { element =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
@@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest](
}
} else { // Sampling without replacement
val rand = new Random(split.seed)
- prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
+ prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 145e419c53..f832633646 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,10 +1,7 @@
package spark.rdd
-import spark.Partitioner
-import spark.RDD
-import spark.ShuffleDependency
-import spark.SparkEnv
-import spark.Split
+import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
+
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
@@ -34,7 +31,7 @@ class ShuffledRDD[K, V](
val dep = new ShuffleDependency(parent, part)
override val dependencies = List(dep)
- override def compute(split: Split): Iterator[(K, V)] = {
+ override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
}
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index f0b9225f7c..a08473f7be 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -2,20 +2,17 @@ package spark.rdd
import scala.collection.mutable.ArrayBuffer
-import spark.Dependency
-import spark.RangeDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+
private[spark] class UnionSplit[T: ClassManifest](
- idx: Int,
+ idx: Int,
rdd: RDD[T],
split: Split)
extends Split
with Serializable {
-
- def iterator() = rdd.iterator(split)
+
+ def iterator(context: TaskContext) = rdd.iterator(split, context)
def preferredLocations() = rdd.preferredLocations(split)
override val index: Int = idx
}
@@ -25,7 +22,7 @@ class UnionRDD[T: ClassManifest](
@transient rdds: Seq[RDD[T]])
extends RDD[T](sc)
with Serializable {
-
+
@transient
val splits_ : Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
@@ -44,13 +41,14 @@ class UnionRDD[T: ClassManifest](
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
- deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
+ deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
deps.toList
}
-
- override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+
+ override def compute(s: Split, context: TaskContext): Iterator[T] =
+ s.asInstanceOf[UnionSplit[T]].iterator(context)
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 80f0150c45..92d667ff1e 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,21 +1,19 @@
package spark.rdd
-import spark.Dependency
-import spark.OneToOneDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+
private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
- idx: Int,
+ idx: Int,
rdd1: RDD[T],
rdd2: RDD[U],
split1: Split,
split2: Split)
extends Split
with Serializable {
-
- def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))
+
+ def iterator(context: TaskContext): Iterator[(T, U)] =
+ rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
def preferredLocations(): Seq[String] =
rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
@@ -46,8 +44,9 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
@transient
override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
-
- override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()
+
+ override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
+ s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 5c71207d43..29757b1178 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -16,8 +16,8 @@ import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
/**
- * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
- * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
+ * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
+ * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
@@ -73,7 +73,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
// that's not going to be a realistic assumption in general
-
+
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
@@ -94,7 +94,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs(rdd.id)
}
-
+
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
}
@@ -326,7 +326,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- val result = job.func(taskContext, rdd.iterator(split))
+ val result = job.func(taskContext, rdd.iterator(split, taskContext))
+ taskContext.executeOnCompleteCallbacks()
job.listener.taskSucceeded(0, result)
} catch {
case e: Exception =>
@@ -353,7 +354,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
}
-
+
def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
@@ -395,7 +396,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val task = event.task
val stage = idToStage(task.stageId)
event.reason match {
- case Success =>
+ case Success =>
logInfo("Completed " + task)
if (event.accumUpdates != null) {
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
@@ -519,7 +520,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
updateCacheLocs()
}
}
-
+
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 2ebd4075a2..e492279b4e 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -10,12 +10,14 @@ private[spark] class ResultTask[T, U](
@transient locs: Seq[String],
val outputId: Int)
extends Task[U](stageId) {
-
+
val split = rdd.splits(partition)
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
- func(context, rdd.iterator(split))
+ val result = func(context, rdd.iterator(split, context))
+ context.executeOnCompleteCallbacks()
+ result
}
override def preferredLocations: Seq[String] = locs
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 60105c42b6..bd1911fce2 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -70,19 +70,19 @@ private[spark] object ShuffleMapTask {
private[spark] class ShuffleMapTask(
stageId: Int,
- var rdd: RDD[_],
+ var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
- var partition: Int,
+ var partition: Int,
@transient var locs: Seq[String])
extends Task[MapStatus](stageId)
with Externalizable
with Logging {
def this() = this(0, null, null, 0, null)
-
+
var split = if (rdd == null) {
- null
- } else {
+ null
+ } else {
rdd.splits(partition)
}
@@ -113,9 +113,11 @@ private[spark] class ShuffleMapTask(
val numOutputSplits = dep.partitioner.numPartitions
val partitioner = dep.partitioner
+ val taskContext = new TaskContext(stageId, partition, attemptId)
+
// Partition the map output.
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
- for (elem <- rdd.iterator(split)) {
+ for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(pair._1)
buckets(bucketId) += pair
@@ -133,6 +135,9 @@ private[spark] class ShuffleMapTask(
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
+ // Execute the callbacks on task completion.
+ taskContext.executeOnCompleteCallbacks()
+
return new MapStatus(blockManager.blockManagerId, compressedSizes)
}
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 32a896e5a4..a5384d3bda 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -30,6 +30,7 @@ import time
import urllib2
from optparse import OptionParser
from sys import stderr
+import boto
from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType
from boto import ec2
@@ -85,6 +86,8 @@ def parse_args():
help="'mesos' for a mesos cluster, 'standalone' for a standalone spark cluster (default: mesos)")
parser.add_option("-u", "--user", default="root",
help="The ssh user you want to connect as (default: root)")
+ parser.add_option("--delete-groups", action="store_true", default=False,
+ help="When destroying a cluster, also destroy the security groups that were created")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -283,16 +286,17 @@ def launch_cluster(conn, opts, cluster_name):
slave_nodes = []
for zone in zones:
num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
- slave_res = image.run(key_name = opts.key_pair,
- security_groups = [slave_group],
- instance_type = opts.instance_type,
- placement = zone,
- min_count = num_slaves_this_zone,
- max_count = num_slaves_this_zone,
- block_device_map = block_map)
- slave_nodes += slave_res.instances
- print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone,
- zone, slave_res.id)
+ if num_slaves_this_zone > 0:
+ slave_res = image.run(key_name = opts.key_pair,
+ security_groups = [slave_group],
+ instance_type = opts.instance_type,
+ placement = zone,
+ min_count = num_slaves_this_zone,
+ max_count = num_slaves_this_zone,
+ block_device_map = block_map)
+ slave_nodes += slave_res.instances
+ print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone,
+ zone, slave_res.id)
i += 1
# Launch masters
@@ -556,20 +560,48 @@ def main():
print "Terminating zoo..."
for inst in zoo_nodes:
inst.terminate()
+
# Delete security groups as well
- group_names = [cluster_name + "-master", cluster_name + "-slaves", cluster_name + "-zoo"]
- groups = conn.get_all_security_groups()
- for group in groups:
- if group.name in group_names:
- print "Deleting security group " + group.name
- # Delete individual rules before deleting group to remove dependencies
- for rule in group.rules:
- for grant in rule.grants:
- group.revoke(ip_protocol=rule.ip_protocol,
- from_port=rule.from_port,
- to_port=rule.to_port,
- src_group=grant)
- conn.delete_security_group(group.name)
+ if opts.delete_groups:
+ print "Deleting security groups (this will take some time)..."
+ group_names = [cluster_name + "-master", cluster_name + "-slaves", cluster_name + "-zoo"]
+
+ attempt = 1;
+ while attempt <= 3:
+ print "Attempt %d" % attempt
+ groups = [g for g in conn.get_all_security_groups() if g.name in group_names]
+ success = True
+ # Delete individual rules in all groups before deleting groups to
+ # remove dependencies between them
+ for group in groups:
+ print "Deleting rules in security group " + group.name
+ for rule in group.rules:
+ for grant in rule.grants:
+ success &= group.revoke(ip_protocol=rule.ip_protocol,
+ from_port=rule.from_port,
+ to_port=rule.to_port,
+ src_group=grant)
+
+ # Sleep for AWS eventual-consistency to catch up, and for instances
+ # to terminate
+ time.sleep(30) # Yes, it does have to be this long :-(
+ for group in groups:
+ try:
+ conn.delete_security_group(group.name)
+ print "Deleted security group " + group.name
+ except boto.exception.EC2ResponseError:
+ success = False;
+ print "Failed to delete security group " + group.name
+
+ # Unfortunately, group.revoke() returns True even if a rule was not
+ # deleted, so this needs to be rerun if something fails
+ if success: break;
+
+ attempt += 1
+
+ if not success:
+ print "Failed to delete all security groups after 3 tries."
+ print "Try re-running in a few minutes."
elif action == "login":
(master_nodes, slave_nodes, zoo_nodes) = get_existing_cluster(
diff --git a/examples/pom.xml b/examples/pom.xml
index d2643f046c..782c026d73 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -55,6 +55,7 @@
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
</dependency>
</dependencies>
<build>
@@ -81,10 +82,12 @@
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
</dependency>
</dependencies>
<build>