author     Reynold Xin <rxin@cs.berkeley.edu>  2012-12-13 16:41:15 -0800
committer  Reynold Xin <rxin@cs.berkeley.edu>  2012-12-13 16:41:15 -0800
commit     4f076e105ee30edcb1941216c79d017c5175d9b8 (patch)
tree       6a6d7c87288440c5d5c9b9fa93ffc656bc8a5edd
parent     eacb98e90075ca3082ad7c832b24719f322d9eb2 (diff)
SPARK-635: Pass a TaskContext object to the compute() interface and use
it to close the Hadoop input stream. Incorporated Matei's comments.
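
For context, here is a minimal, self-contained sketch of the pattern this change introduces. The TaskContext below is a simplified stand-in for spark.TaskContext, and the small main method exists only for illustration; it is not the actual Spark code. Inside compute(split, context), an RDD can register a cleanup callback, such as closing a Hadoop RecordReader, which the task runner invokes once the task completes.

import scala.collection.mutable.ArrayBuffer

// Simplified stand-in for spark.TaskContext after this commit: completion
// callbacks are stored as no-argument functions (() => Unit).
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
  @transient
  val onCompleteCallbacks = new ArrayBuffer[() => Unit]

  def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += f
  }
}

object TaskContextSketch {
  def main(args: Array[String]) {
    val context = new TaskContext(stageId = 0, splitId = 0, attemptId = 0L)

    // Inside compute(split, context), an RDD such as HadoopRDD can register
    // cleanup work against the task, e.g. closing its input reader.
    val reader = new java.io.StringReader("records")
    context.addOnCompleteCallback(() => reader.close())

    // ... the task consumes the iterator returned by compute() ...

    // At task completion the runner invokes every registered callback.
    context.onCompleteCallbacks.foreach(f => f())
  }
}
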
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala | 5
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 8
-rw-r--r--  core/src/main/scala/spark/TaskContext.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala | 2
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala | 6
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/CoalescedRDD.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/FilteredRDD.scala | 3
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/GlommedRDD.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/HadoopRDD.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsRDD.scala | 3
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/MappedRDD.scala | 3
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 11
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/SampledRDD.scala | 6
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala | 2
-rw-r--r--  core/src/main/scala/spark/rdd/UnionRDD.scala | 6
-rw-r--r--  core/src/main/scala/spark/rdd/ZippedRDD.scala | 8
20 files changed, 46 insertions(+), 49 deletions(-)
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index e9c545a2cf..3d79078733 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -167,8 +167,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
}
// Gets or computes an RDD split
- def getOrCompute[T](
- rdd: RDD[T], split: Split, taskContext: TaskContext, storageLevel: StorageLevel)
+ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
@@ -211,7 +210,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
val elements = new ArrayBuffer[Any]
- elements ++= rdd.compute(split, taskContext)
+ elements ++= rdd.compute(split, context)
try {
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index c53eab67e5..bb4c13c494 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -81,7 +81,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def splits: Array[Split]
/** Function for computing a given partition. */
- def compute(split: Split, taskContext: TaskContext): Iterator[T]
+ def compute(split: Split, context: TaskContext): Iterator[T]
/** How this RDD depends on any parent RDDs. */
@transient val dependencies: List[Dependency[_]]
@@ -155,11 +155,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
- final def iterator(split: Split, taskContext: TaskContext): Iterator[T] = {
+ final def iterator(split: Split, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, taskContext, storageLevel)
+ SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
} else {
- compute(split, taskContext)
+ compute(split, context)
}
}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index b352db8167..d2746b26b3 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -6,11 +6,11 @@ import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
@transient
- val onCompleteCallbacks = new ArrayBuffer[Unit => Unit]
+ val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
- def registerOnCompleteCallback(f: Unit => Unit) {
+ def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}
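
The type change above is the substantive fix in this file: Unit => Unit is a one-argument function over the Unit value, so the old call sites (registerOnCompleteCallback(Unit => reader.close())) were really passing a lambda whose parameter happened to be named Unit, not a no-argument callback. () => Unit (Function0) expresses the intent directly. A small standalone illustration of the difference follows; the object name and printlns are for illustration only.

object CallbackTypes extends App {
  // Old element type: a function from Unit to Unit still expects an argument.
  // `Unit => reader.close()` only names the lambda's parameter "Unit"; it does
  // not declare a zero-argument function.
  val oldStyle: Unit => Unit = _ => println("invoked with a Unit argument")
  oldStyle(())   // must be applied to the unit value ()

  // New element type: a genuine no-argument callback.
  val newStyle: () => Unit = () => println("invoked with no arguments")
  newStyle()
}
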
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 8209c36871..f98528a183 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -28,7 +28,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
override def splits = splits_
- override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
+ override def compute(split: Split, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDSplit].blockId
blockManager.get(blockId) match {
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 6bc0938ce2..4a7e5f3d06 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -36,10 +36,10 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
- override def compute(split: Split, taskContext: TaskContext) = {
+ override def compute(split: Split, context: TaskContext) = {
val currSplit = split.asInstanceOf[CartesianSplit]
- for (x <- rdd1.iterator(currSplit.s1, taskContext);
- y <- rdd2.iterator(currSplit.s2, taskContext)) yield (x, y)
+ for (x <- rdd1.iterator(currSplit.s1, context);
+ y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
override val dependencies = List(
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 6037681cfd..de0d9fad88 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -68,7 +68,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def preferredLocations(s: Split) = Nil
- override def compute(s: Split, taskContext: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
@@ -78,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
// Read them from the parent
- for ((k, v) <- rdd.iterator(itsSplit, taskContext)) {
+ for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
}
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 06ffc9c42c..1affe0e0ef 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -31,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
override def splits = splits_
- override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
+ override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
- parentSplit => prev.iterator(parentSplit, taskContext)
+ parentSplit => prev.iterator(parentSplit, context)
}
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index 14a80d82c7..b148da28de 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -7,6 +7,5 @@ private[spark]
class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, taskContext: TaskContext) =
- prev.iterator(split, taskContext).filter(f)
+ override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 64f8c51d6d..785662b2da 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -11,6 +11,6 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, taskContext: TaskContext) =
- prev.iterator(split, taskContext).flatMap(f)
+ override def compute(split: Split, context: TaskContext) =
+ prev.iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index d6b1b27d3e..fac8ffb4cb 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -7,6 +7,6 @@ private[spark]
class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, taskContext: TaskContext) =
- Array(prev.iterator(split, taskContext).toArray).iterator
+ override def compute(split: Split, context: TaskContext) =
+ Array(prev.iterator(split, context).toArray).iterator
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index c6c035a096..ab163f569b 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -66,7 +66,7 @@ class HadoopRDD[K, V](
override def splits = splits_
- override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
+ override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopSplit]
var reader: RecordReader[K, V] = null
@@ -75,7 +75,7 @@ class HadoopRDD[K, V](
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
- taskContext.registerOnCompleteCallback(Unit => reader.close())
+ context.addOnCompleteCallback(() => reader.close())
val key: K = reader.createKey()
val value: V = reader.createValue()
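
The callback registered here is what actually closes the Hadoop input stream. The invoking side is not part of this diff, but the intended flow is presumably that the task runner executes every registered callback when the task finishes, even on failure. Below is a hedged sketch of that flow with a StringReader standing in for the RecordReader; runTask and every name in it are illustrative assumptions, not Spark APIs.

import scala.collection.mutable.ArrayBuffer

object CompletionCallbackSketch {
  // Sketch of a task runner driving registered callbacks in a finally block,
  // so the reader is closed even if consuming the records throws.
  def runTask[T](callbacks: ArrayBuffer[() => Unit])(body: => T): T = {
    try {
      body
    } finally {
      callbacks.foreach(f => f())   // e.g. reader.close()
    }
  }

  def main(args: Array[String]) {
    val callbacks = new ArrayBuffer[() => Unit]
    val reader = new java.io.StringReader("key\tvalue")
    callbacks += (() => { reader.close(); println("reader closed") })

    runTask(callbacks) {
      println("consuming records from the reader stand-in")
    }
  }
}
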
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index 715c240060..c764505345 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -14,6 +14,5 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, taskContext: TaskContext) =
- f(prev.iterator(split, taskContext))
+ override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context))
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 39f3c7b5f7..3d9888bd34 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -17,6 +17,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, taskContext: TaskContext) =
- f(split.index, prev.iterator(split, taskContext))
+ override def compute(split: Split, context: TaskContext) =
+ f(split.index, prev.iterator(split, context))
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index d82ab3f671..70fa8f4497 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -10,6 +10,5 @@ class MappedRDD[U: ClassManifest, T: ClassManifest](
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, taskContext: TaskContext) =
- prev.iterator(split, taskContext).map(f)
+ override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f)
}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 61f4cbbe94..197ed5ea17 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -53,17 +53,18 @@ class NewHadoopRDD[K, V](
override def splits = splits_
- override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
+ override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
- val context = newTaskAttemptContext(conf, attemptId)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
- val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
- reader.initialize(split.serializableHadoopSplit.value, context)
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
// Register an on-task-completion callback to close the input stream.
- taskContext.registerOnCompleteCallback(Unit => reader.close())
+ context.addOnCompleteCallback(() => reader.close())
var havePair = false
var finished = false
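
One wrinkle in this file: compute() already needed a local Hadoop TaskAttemptContext, which used to be named context. With the new context: TaskContext parameter, keeping that name would shadow the parameter, hence the rename to hadoopAttemptContext. A minimal illustration of the shadowing problem follows; both classes are simplified stand-ins, not the real Spark or Hadoop types.

object ShadowingSketch {
  class TaskContext { def addOnCompleteCallback(f: () => Unit): Unit = () }
  class HadoopTaskAttemptContext   // stand-in for Hadoop's TaskAttemptContext

  def compute(context: TaskContext): Unit = {
    // Had this local kept the name `context`, the call below would resolve to
    // the Hadoop attempt context and fail to compile.
    val hadoopAttemptContext = new HadoopTaskAttemptContext
    context.addOnCompleteCallback(() => println("close the record reader"))
  }

  def main(args: Array[String]) {
    compute(new TaskContext)
  }
}
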
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index b34c7ea5b9..336e193217 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -29,7 +29,7 @@ class PipedRDD[T: ClassManifest](
override val dependencies = List(new OneToOneDependency(parent))
- override def compute(split: Split, taskContext: TaskContext): Iterator[String] = {
+ override def compute(split: Split, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
@@ -52,7 +52,7 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
- for (elem <- parent.iterator(split, taskContext)) {
+ for (elem <- parent.iterator(split, context)) {
out.println(elem)
}
out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 07a1487f3a..6e4797aabb 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -32,13 +32,13 @@ class SampledRDD[T: ClassManifest](
override def preferredLocations(split: Split) =
prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
- override def compute(splitIn: Split, taskContext: TaskContext) = {
+ override def compute(splitIn: Split, context: TaskContext) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
val poisson = new Poisson(frac, new DRand(split.seed))
- prev.iterator(split.prev, taskContext).flatMap { element =>
+ prev.iterator(split.prev, context).flatMap { element =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
@@ -48,7 +48,7 @@ class SampledRDD[T: ClassManifest](
}
} else { // Sampling without replacement
val rand = new Random(split.seed)
- prev.iterator(split.prev, taskContext).filter(x => (rand.nextDouble <= frac))
+ prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index c736e92117..f832633646 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -31,7 +31,7 @@ class ShuffledRDD[K, V](
val dep = new ShuffleDependency(parent, part)
override val dependencies = List(dep)
- override def compute(split: Split, taskContext: TaskContext): Iterator[(K, V)] = {
+ override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
}
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 4b9cab8774..a08473f7be 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -12,7 +12,7 @@ private[spark] class UnionSplit[T: ClassManifest](
extends Split
with Serializable {
- def iterator(taskContext: TaskContext) = rdd.iterator(split, taskContext)
+ def iterator(context: TaskContext) = rdd.iterator(split, context)
def preferredLocations() = rdd.preferredLocations(split)
override val index: Int = idx
}
@@ -47,8 +47,8 @@ class UnionRDD[T: ClassManifest](
deps.toList
}
- override def compute(s: Split, taskContext: TaskContext): Iterator[T] =
- s.asInstanceOf[UnionSplit[T]].iterator(taskContext)
+ override def compute(s: Split, context: TaskContext): Iterator[T] =
+ s.asInstanceOf[UnionSplit[T]].iterator(context)
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index b987ca5fdf..92d667ff1e 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -12,8 +12,8 @@ private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
extends Split
with Serializable {
- def iterator(taskContext: TaskContext): Iterator[(T, U)] =
- rdd1.iterator(split1, taskContext).zip(rdd2.iterator(split2, taskContext))
+ def iterator(context: TaskContext): Iterator[(T, U)] =
+ rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
def preferredLocations(): Seq[String] =
rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
@@ -45,8 +45,8 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
@transient
override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
- override def compute(s: Split, taskContext: TaskContext): Iterator[(T, U)] =
- s.asInstanceOf[ZippedSplit[T, U]].iterator(taskContext)
+ override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
+ s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()