From 9f964612a1e3f1c80de52e1015dee510489ad8ed Mon Sep 17 00:00:00 2001
From: Peter Sankauskas
Date: Mon, 10 Dec 2012 17:44:09 -0800
Subject: SPARK-626: Remove rules before removing security groups, with a pause
 in between to wait for AWS eventual consistency to catch up.

---
 ec2/spark_ec2.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 2ab11dbd34..2e8d2e17f5 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -557,18 +557,22 @@ def main():
         inst.terminate()
       # Delete security groups as well
       group_names = [cluster_name + "-master", cluster_name + "-slaves", cluster_name + "-zoo"]
-      groups = conn.get_all_security_groups()
+      groups = [g for g in conn.get_all_security_groups() if g.name in group_names]
+      # Delete individual rules in all groups before deleting groups to remove
+      # dependencies between them
       for group in groups:
-        if group.name in group_names:
-          print "Deleting security group " + group.name
-          # Delete individual rules before deleting group to remove dependencies
-          for rule in group.rules:
-            for grant in rule.grants:
-                group.revoke(ip_protocol=rule.ip_protocol,
-                       from_port=rule.from_port,
-                       to_port=rule.to_port,
-                       src_group=grant)
-          conn.delete_security_group(group.name)
+        print "Deleting rules in security group " + group.name
+        for rule in group.rules:
+          for grant in rule.grants:
+              group.revoke(ip_protocol=rule.ip_protocol,
+                     from_port=rule.from_port,
+                     to_port=rule.to_port,
+                     src_group=grant)
+      # Sleep for AWS eventual-consistency to catch up
+      time.sleep(30)  # Yes, it does have to be this long :-(
+      for group in groups:
+        print "Deleting security group " + group.name
+        conn.delete_security_group(group.name)

   elif action == "login":
     (master_nodes, slave_nodes, zoo_nodes) = get_existing_cluster(
-- 
cgit v1.2.3


From 02d64f966252970ffee393b1f287666da374d237 Mon Sep 17 00:00:00 2001
From: Thomas Dudziak
Date: Mon, 10 Dec 2012 21:27:54 -0800
Subject: Mark hadoop dependencies provided in all library artifacts

---
 bagel/pom.xml    | 3 +++
 examples/pom.xml | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/bagel/pom.xml b/bagel/pom.xml
index b462801589..a8256a6e8b 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -55,6 +55,7 @@
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-core</artifactId>
+        <scope>provided</scope>
       </dependency>
@@ -81,10 +82,12 @@
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-core</artifactId>
+        <scope>provided</scope>
       </dependency>
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-client</artifactId>
+        <scope>provided</scope>
       </dependency>
diff --git a/examples/pom.xml b/examples/pom.xml
index d2643f046c..782c026d73 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -55,6 +55,7 @@
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-core</artifactId>
+        <scope>provided</scope>
       </dependency>
@@ -81,10 +82,12 @@
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-core</artifactId>
+        <scope>provided</scope>
       </dependency>
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-client</artifactId>
+        <scope>provided</scope>
       </dependency>
-- 
cgit v1.2.3


From f97ce3ae14ed05b3e5d3e6cd137ee5164813634e Mon Sep 17 00:00:00 2001
From: Peter Sankauskas
Date: Tue, 11 Dec 2012 10:48:21 -0800
Subject: SPARK-626: Making security group deletion optional, handling retries
 when deleting security groups fails, fixing a bug when using all zones but
 only 1 slave.
--- ec2/spark_ec2.py | 82 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 2e8d2e17f5..2cc8431238 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -30,6 +30,7 @@ import time import urllib2 from optparse import OptionParser from sys import stderr +import boto from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType from boto import ec2 @@ -85,6 +86,8 @@ def parse_args(): help="'mesos' for a mesos cluster, 'standalone' for a standalone spark cluster (default: mesos)") parser.add_option("-u", "--user", default="root", help="The ssh user you want to connect as (default: root)") + parser.add_option("--delete-groups", action="store_true", default=False, + help="When destroying a cluster, also destroy the security groups that were created") (opts, args) = parser.parse_args() if len(args) != 2: @@ -283,16 +286,17 @@ def launch_cluster(conn, opts, cluster_name): slave_nodes = [] for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - slave_res = image.run(key_name = opts.key_pair, - security_groups = [slave_group], - instance_type = opts.instance_type, - placement = zone, - min_count = num_slaves_this_zone, - max_count = num_slaves_this_zone, - block_device_map = block_map) - slave_nodes += slave_res.instances - print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, - zone, slave_res.id) + if num_slaves_this_zone > 0: + slave_res = image.run(key_name = opts.key_pair, + security_groups = [slave_group], + instance_type = opts.instance_type, + placement = zone, + min_count = num_slaves_this_zone, + max_count = num_slaves_this_zone, + block_device_map = block_map) + slave_nodes += slave_res.instances + print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, + zone, slave_res.id) i += 1 # Launch masters @@ -555,24 +559,48 @@ def main(): print "Terminating zoo..." for inst in zoo_nodes: inst.terminate() + # Delete security groups as well - group_names = [cluster_name + "-master", cluster_name + "-slaves", cluster_name + "-zoo"] - groups = [g for g in conn.get_all_security_groups() if g.name in group_names] - # Delete individual rules in all groups before deleting groups to remove - # dependencies between them - for group in groups: - print "Deleting rules in security group " + group.name - for rule in group.rules: - for grant in rule.grants: - group.revoke(ip_protocol=rule.ip_protocol, - from_port=rule.from_port, - to_port=rule.to_port, - src_group=grant) - # Sleep for AWS eventual-consistency to catch up - time.sleep(30) # Yes, it does have to be this long :-( - for group in groups: - print "Deleting security group " + group.name - conn.delete_security_group(group.name) + if opts.delete_groups: + print "Deleting security groups (this will take some time)..." 
+ group_names = [cluster_name + "-master", cluster_name + "-slaves", cluster_name + "-zoo"] + + attempt = 1; + while attempt <= 3: + print "Attempt %d" % attempt + groups = [g for g in conn.get_all_security_groups() if g.name in group_names] + success = True + # Delete individual rules in all groups before deleting groups to + # remove dependencies between them + for group in groups: + print "Deleting rules in security group " + group.name + for rule in group.rules: + for grant in rule.grants: + success &= group.revoke(ip_protocol=rule.ip_protocol, + from_port=rule.from_port, + to_port=rule.to_port, + src_group=grant) + + # Sleep for AWS eventual-consistency to catch up, and for instances + # to terminate + time.sleep(30) # Yes, it does have to be this long :-( + for group in groups: + try: + conn.delete_security_group(group.name) + print "Deleted security group " + group.name + except boto.exception.EC2ResponseError: + success = False; + print "Failed to delete security group " + group.name + + # Unfortunately, group.revoke() returns True even if a rule was not + # deleted, so this needs to be rerun if something fails + if success: break; + + attempt += 1 + + if not success: + print "Failed to delete all security groups after 3 tries." + print "Try re-running in a few minutes." elif action == "login": (master_nodes, slave_nodes, zoo_nodes) = get_existing_cluster( -- cgit v1.2.3 From eacb98e90075ca3082ad7c832b24719f322d9eb2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Dec 2012 15:41:53 -0800 Subject: SPARK-635: Pass a TaskContext object to compute() interface and use that to close Hadoop input stream. --- core/src/main/scala/spark/CacheTracker.scala | 30 ++++++++++++---------- core/src/main/scala/spark/PairRDDFunctions.scala | 15 ++++++----- core/src/main/scala/spark/ParallelCollection.scala | 17 ++++++------ core/src/main/scala/spark/RDD.scala | 8 +++--- core/src/main/scala/spark/TaskContext.scala | 19 +++++++++++++- .../main/scala/spark/api/java/JavaRDDLike.scala | 25 +++++++++--------- core/src/main/scala/spark/rdd/BlockRDD.scala | 25 ++++++++---------- core/src/main/scala/spark/rdd/CartesianRDD.scala | 17 ++++++------ core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 30 +++++++++------------- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 9 +++---- core/src/main/scala/spark/rdd/FilteredRDD.scala | 8 +++--- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 10 ++++---- core/src/main/scala/spark/rdd/GlommedRDD.scala | 8 +++--- core/src/main/scala/spark/rdd/HadoopRDD.scala | 22 ++++++++-------- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 10 ++++---- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 7 +++-- core/src/main/scala/spark/rdd/MappedRDD.scala | 9 +++---- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 27 +++++++++---------- core/src/main/scala/spark/rdd/PipedRDD.scala | 11 +++----- core/src/main/scala/spark/rdd/SampledRDD.scala | 15 +++++------ core/src/main/scala/spark/rdd/ShuffledRDD.scala | 9 +++---- core/src/main/scala/spark/rdd/UnionRDD.scala | 22 ++++++++-------- core/src/main/scala/spark/rdd/ZippedRDD.scala | 19 +++++++------- .../main/scala/spark/scheduler/DAGScheduler.scala | 17 ++++++------ .../main/scala/spark/scheduler/ResultTask.scala | 6 +++-- .../scala/spark/scheduler/ShuffleMapTask.scala | 17 +++++++----- 26 files changed, 207 insertions(+), 205 deletions(-) diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index c5db6ce63a..e9c545a2cf 100644 --- 
a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -1,5 +1,9 @@ package spark +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + import akka.actor._ import akka.dispatch._ import akka.pattern.ask @@ -8,10 +12,6 @@ import akka.util.Duration import akka.util.Timeout import akka.util.duration._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - import spark.storage.BlockManager import spark.storage.StorageLevel @@ -41,7 +41,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - + def receive = { case SlaveCacheStarted(host: String, size: Long) => slaveCapacity.put(host, size) @@ -92,14 +92,14 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) extends Logging { - + // Tracker actor on the master, or remote reference to it on workers val ip: String = System.getProperty("spark.master.host", "localhost") val port: Int = System.getProperty("spark.master.port", "7077").toInt val actorName: String = "CacheTracker" val timeout = 10.seconds - + var trackerActor: ActorRef = if (isMaster) { val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) logInfo("Registered CacheTrackerActor actor") @@ -132,7 +132,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b throw new SparkException("Error reply received from CacheTracker") } } - + // Registers an RDD (on master only) def registerRDD(rddId: Int, numPartitions: Int) { registeredRddIds.synchronized { @@ -143,7 +143,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b } } } - + // For BlockManager.scala only def cacheLost(host: String) { communicate(MemoryCacheLost(host)) @@ -155,19 +155,21 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b def getCacheStatus(): Seq[(String, Long, Long)] = { askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] } - + // For BlockManager.scala only def notifyFromBlockManager(t: AddedToCache) { communicate(t) } - + // Get a snapshot of the currently known locations def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] } - + // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = { + def getOrCompute[T]( + rdd: RDD[T], split: Split, taskContext: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) blockManager.get(key) match { @@ -209,7 +211,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b // TODO: also register a listener for when it unloads logInfo("Computing partition " + split) val elements = new ArrayBuffer[Any] - elements ++= rdd.compute(split) + elements ++= rdd.compute(split, taskContext) try { // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) diff --git 
a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index e5bb639cfd..08ae06e865 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -35,11 +35,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( with Serializable { /** - * Generic function to combine the elements for each key using a custom set of aggregation + * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C * Note that V and C can be different -- for example, one might group an RDD of type * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: - * + * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) * - `mergeCombiners`, to combine two C's into a single one. @@ -118,7 +118,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** Count the number of elements for each key, and return the result to the master as a Map. */ def countByKey(): Map[K, Long] = self.map(_._1).countByValue() - /** + /** * (Experimental) Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ @@ -224,7 +224,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } } - /** + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the default * parallelism level. */ @@ -628,7 +628,8 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)] override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) override val partitioner = prev.partitioner - override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))} + override def compute(split: Split, taskContext: TaskContext) = + prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))} } private[spark] @@ -639,8 +640,8 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U] override val dependencies = List(new OneToOneDependency(prev)) override val partitioner = prev.partitioner - override def compute(split: Split) = { - prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } + override def compute(split: Split, taskContext: TaskContext) = { + prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9b57ae3b4f..a27f766e31 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -8,8 +8,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( val slice: Int, values: Seq[T]) extends Split with Serializable { - - def iterator(): Iterator[T] = values.iterator + + def iterator: Iterator[T] = values.iterator override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt @@ -22,7 +22,7 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - sc: SparkContext, + sc: SparkContext, @transient data: Seq[T], numSlices: Int) extends RDD[T](sc) { @@ -38,17 +38,18 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def splits = splits_.asInstanceOf[Array[Split]] - override def compute(s: 
Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - + override def compute(s: Split, taskContext: TaskContext) = + s.asInstanceOf[ParallelCollectionSplit[T]].iterator + override def preferredLocations(s: Split): Seq[String] = Nil - + override val dependencies: List[Dependency[_]] = Nil } private object ParallelCollection { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range - * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes + * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes * it efficient to run Spark over RDDs representing large sets of numbers. */ def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { @@ -58,7 +59,7 @@ private object ParallelCollection { seq match { case r: Range.Inclusive => { val sign = if (r.step < 0) { - -1 + -1 } else { 1 } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6270e018b3..c53eab67e5 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -81,7 +81,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def splits: Array[Split] /** Function for computing a given partition. */ - def compute(split: Split): Iterator[T] + def compute(split: Split, taskContext: TaskContext): Iterator[T] /** How this RDD depends on any parent RDDs. */ @transient val dependencies: List[Dependency[_]] @@ -155,11 +155,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ - final def iterator(split: Split): Iterator[T] = { + final def iterator(split: Split, taskContext: TaskContext): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) + SparkEnv.get.cacheTracker.getOrCompute[T](this, split, taskContext, storageLevel) } else { - compute(split) + compute(split, taskContext) } } diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala index c14377d17b..b352db8167 100644 --- a/core/src/main/scala/spark/TaskContext.scala +++ b/core/src/main/scala/spark/TaskContext.scala @@ -1,3 +1,20 @@ package spark -class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable +import scala.collection.mutable.ArrayBuffer + + +class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { + + @transient + val onCompleteCallbacks = new ArrayBuffer[Unit => Unit] + + // Add a callback function to be executed on task completion. An example use + // is for HadoopRDD to register a callback to close the input stream. 
+ def registerOnCompleteCallback(f: Unit => Unit) { + onCompleteCallbacks += f + } + + def executeOnCompleteCallbacks() { + onCompleteCallbacks.foreach{_()} + } +} diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 482eb9281a..81d3a94466 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -1,16 +1,15 @@ package spark.api.java -import spark.{SparkContext, Split, RDD} +import java.util.{List => JList} +import scala.Tuple2 +import scala.collection.JavaConversions._ + +import spark.{SparkContext, Split, RDD, TaskContext} import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import spark.partial.{PartialResult, BoundedDouble} import spark.storage.StorageLevel -import java.util.{List => JList} - -import scala.collection.JavaConversions._ -import java.{util, lang} -import scala.Tuple2 trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This @@ -24,7 +23,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** The [[spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context - + /** A unique ID for this RDD (within its SparkContext). */ def id: Int = rdd.id @@ -36,7 +35,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ - def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split)) + def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] = + asJavaIterator(rdd.iterator(split, taskContext)) // Transformations (return a new RDD) @@ -99,7 +99,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType()) } - /** * Return a new RDD by applying a function to each partition of this RDD. */ @@ -183,7 +182,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } // Actions (launch a job to return a value to the user program) - + /** * Applies a function f to all elements of this RDD. */ @@ -200,7 +199,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { val arr: java.util.Collection[T] = rdd.collect().toSeq new java.util.ArrayList(arr) } - + /** * Reduces the elements of this RDD using the specified associative binary operator. */ @@ -208,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to + * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to * modify t1 and return it as its result value to avoid object allocation; however, it should not * modify t2. */ @@ -251,7 +250,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. 
*/ def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2))))) + mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index cb73976aed..8209c36871 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -2,11 +2,8 @@ package spark.rdd import scala.collection.mutable.HashMap -import spark.Dependency -import spark.RDD -import spark.SparkContext -import spark.SparkEnv -import spark.Split +import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext} + private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split { val index = idx @@ -19,29 +16,29 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St @transient val splits_ = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] - }).toArray - - @transient + }).toArray + + @transient lazy val locations_ = { - val blockManager = SparkEnv.get.blockManager + val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ - val locations = blockManager.getLocations(blockIds) + val locations = blockManager.getLocations(blockIds) HashMap(blockIds.zip(locations):_*) } override def splits = splits_ - override def compute(split: Split): Iterator[T] = { - val blockManager = SparkEnv.get.blockManager + override def compute(split: Split, taskContext: TaskContext): Iterator[T] = { + val blockManager = SparkEnv.get.blockManager val blockId = split.asInstanceOf[BlockRDDSplit].blockId blockManager.get(blockId) match { case Some(block) => block.asInstanceOf[Iterator[T]] - case None => + case None => throw new Exception("Could not compute split, block " + blockId + " not found") } } - override def preferredLocations(split: Split) = + override def preferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) override val dependencies: List[Dependency[_]] = Nil diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 7c354b6b2e..6bc0938ce2 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,9 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext} + private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { @@ -17,9 +15,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( rdd2: RDD[U]) extends RDD[Pair[T, U]](sc) with Serializable { - + val numSplitsInRdd2 = rdd2.splits.size - + @transient val splits_ = { // create the cross product split @@ -38,11 +36,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } - override def compute(split: Split) = { + override def compute(split: Split, taskContext: TaskContext) = { val currSplit = split.asInstanceOf[CartesianSplit] - for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) + for (x <- rdd1.iterator(currSplit.s1, taskContext); + y <- rdd2.iterator(currSplit.s2, taskContext)) yield (x, y) } - + 
override val dependencies = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 50bec9e63b..6037681cfd 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -3,21 +3,15 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import spark.Aggregator -import spark.Dependency -import spark.Logging -import spark.OneToOneDependency -import spark.Partitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split +import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} +import spark.{Dependency, OneToOneDependency, ShuffleDependency} + private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep -private[spark] +private[spark] class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -32,9 +26,9 @@ private[spark] class CoGroupAggregator class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { - + val aggr = new CoGroupAggregator - + @transient override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] @@ -50,7 +44,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } deps.toList } - + @transient val splits_ : Array[Split] = { val firstRdd = rdds.head @@ -69,12 +63,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } override def splits = splits_ - + override val partitioner = Some(part) - + override def preferredLocations(s: Split) = Nil - - override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { + + override def compute(s: Split, taskContext: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size val map = new HashMap[K, Seq[ArrayBuffer[Any]]] @@ -84,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { // Read them from the parent - for ((k, v) <- rdd.iterator(itsSplit)) { + for ((k, v) <- rdd.iterator(itsSplit, taskContext)) { getSeq(k.asInstanceOf[K])(depNum) += v } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0967f4f5df..06ffc9c42c 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,8 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.Split +import spark.{NarrowDependency, RDD, Split, TaskContext} + private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split @@ -32,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) override def splits = splits_ - override def compute(split: Split): Iterator[T] = { + override def compute(split: Split, taskContext: TaskContext): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { - parentSplit => 
prev.iterator(parentSplit) + parentSplit => prev.iterator(parentSplit, taskContext) } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index dfe9dc73f3..14a80d82c7 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -1,12 +1,12 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).filter(f) + override def compute(split: Split, taskContext: TaskContext) = + prev.iterator(split, taskContext).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 3534dc8057..64f8c51d6d 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -1,16 +1,16 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: T => TraversableOnce[U]) extends RDD[U](prev.context) { - + override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).flatMap(f) + + override def compute(split: Split, taskContext: TaskContext) = + prev.iterator(split, taskContext).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index e30564f2da..d6b1b27d3e 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -1,12 +1,12 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator + override def compute(split: Split, taskContext: TaskContext) = + Array(prev.iterator(split, taskContext).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index bf29a1f075..c6c035a096 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -15,19 +15,16 @@ import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils -import spark.Dependency -import spark.RDD -import spark.SerializableWritable -import spark.SparkContext -import spark.Split +import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext} -/** + +/** * A Spark split class that wraps around a Hadoop InputSplit. 
*/ private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) extends Split with Serializable { - + val inputSplit = new SerializableWritable[InputSplit](s) override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt @@ -47,10 +44,10 @@ class HadoopRDD[K, V]( valueClass: Class[V], minSplits: Int) extends RDD[(K, V)](sc) { - + // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) - + @transient val splits_ : Array[Split] = { val inputFormat = createInputFormat(conf) @@ -69,7 +66,7 @@ class HadoopRDD[K, V]( override def splits = splits_ - override def compute(theSplit: Split) = new Iterator[(K, V)] { + override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopSplit] var reader: RecordReader[K, V] = null @@ -77,6 +74,9 @@ class HadoopRDD[K, V]( val fmt = createInputFormat(conf) reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) + // Register an on-task-completion callback to close the input stream. + taskContext.registerOnCompleteCallback(Unit => reader.close()) + val key: K = reader.createKey() val value: V = reader.createValue() var gotNext = false @@ -115,6 +115,6 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - + override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index a904ef62c3..715c240060 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -1,8 +1,7 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( @@ -12,8 +11,9 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( extends RDD[U](prev.context) { override val partitioner = if (preservesPartitioning) prev.partitioner else None - + override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(prev.iterator(split)) + override def compute(split: Split, taskContext: TaskContext) = + f(prev.iterator(split, taskContext)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 14e390c43b..39f3c7b5f7 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -1,8 +1,6 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -19,5 +17,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) prev.partitioner else None override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(split.index, prev.iterator(split)) + override def compute(split: Split, taskContext: TaskContext) = + f(split.index, prev.iterator(split, taskContext)) } \ No newline at end of file diff --git 
a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 59bedad8ef..d82ab3f671 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -1,16 +1,15 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: T => U) extends RDD[U](prev.context) { - + override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).map(f) + override def compute(split: Split, taskContext: TaskContext) = + prev.iterator(split, taskContext).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 7a1a0fb87d..61f4cbbe94 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -1,22 +1,19 @@ package spark.rdd +import java.text.SimpleDateFormat +import java.util.Date + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import java.util.Date -import java.text.SimpleDateFormat +import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext} -import spark.Dependency -import spark.RDD -import spark.SerializableWritable -import spark.SparkContext -import spark.Split -private[spark] +private[spark] class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) extends Split { - + val serializableHadoopSplit = new SerializableWritable(rawSplit) override def hashCode(): Int = (41 * (41 + rddId) + index) @@ -29,7 +26,7 @@ class NewHadoopRDD[K, V]( @transient conf: Configuration) extends RDD[(K, V)](sc) with HadoopMapReduceUtil { - + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) // private val serializableConf = new SerializableWritable(conf) @@ -56,7 +53,7 @@ class NewHadoopRDD[K, V]( override def splits = splits_ - override def compute(theSplit: Split) = new Iterator[(K, V)] { + override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] val conf = confBroadcast.value.value val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) @@ -64,7 +61,10 @@ class NewHadoopRDD[K, V]( val format = inputFormatClass.newInstance val reader = format.createRecordReader(split.serializableHadoopSplit.value, context) reader.initialize(split.serializableHadoopSplit.value, context) - + + // Register an on-task-completion callback to close the input stream. 
+ taskContext.registerOnCompleteCallback(Unit => reader.close()) + var havePair = false var finished = false @@ -72,9 +72,6 @@ class NewHadoopRDD[K, V]( if (!finished && !havePair) { finished = !reader.nextKeyValue havePair = !finished - if (finished) { - reader.close() - } } !finished } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 98ea0c92d6..b34c7ea5b9 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -8,10 +8,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import spark.OneToOneDependency -import spark.RDD -import spark.SparkEnv -import spark.Split +import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} /** @@ -32,12 +29,12 @@ class PipedRDD[T: ClassManifest]( override val dependencies = List(new OneToOneDependency(parent)) - override def compute(split: Split): Iterator[String] = { + override def compute(split: Split, taskContext: TaskContext): Iterator[String] = { val pb = new ProcessBuilder(command) // Add the environmental variables to the process. val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } - + val proc = pb.start() val env = SparkEnv.get @@ -55,7 +52,7 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - for (elem <- parent.iterator(split)) { + for (elem <- parent.iterator(split, taskContext)) { out.println(elem) } out.close() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 87a5268f27..07a1487f3a 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -4,9 +4,8 @@ import java.util.Random import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import spark.RDD -import spark.OneToOneDependency -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -15,7 +14,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali class SampledRDD[T: ClassManifest]( prev: RDD[T], - withReplacement: Boolean, + withReplacement: Boolean, frac: Double, seed: Int) extends RDD[T](prev.context) { @@ -29,17 +28,17 @@ class SampledRDD[T: ClassManifest]( override def splits = splits_.asInstanceOf[Array[Split]] override val dependencies = List(new OneToOneDependency(prev)) - + override def preferredLocations(split: Split) = prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) - override def compute(splitIn: Split) = { + override def compute(splitIn: Split, taskContext: TaskContext) = { val split = splitIn.asInstanceOf[SampledRDDSplit] if (withReplacement) { // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. 
val poisson = new Poisson(frac, new DRand(split.seed)) - prev.iterator(split.prev).flatMap { element => + prev.iterator(split.prev, taskContext).flatMap { element => val count = poisson.nextInt() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often @@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest]( } } else { // Sampling without replacement val rand = new Random(split.seed) - prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac)) + prev.iterator(split.prev, taskContext).filter(x => (rand.nextDouble <= frac)) } } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 145e419c53..c736e92117 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,10 +1,7 @@ package spark.rdd -import spark.Partitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split +import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} + private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -34,7 +31,7 @@ class ShuffledRDD[K, V]( val dep = new ShuffleDependency(parent, part) override val dependencies = List(dep) - override def compute(split: Split): Iterator[(K, V)] = { + override def compute(split: Split, taskContext: TaskContext): Iterator[(K, V)] = { SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index) } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index f0b9225f7c..4b9cab8774 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -2,20 +2,17 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer -import spark.Dependency -import spark.RangeDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext} + private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, + idx: Int, rdd: RDD[T], split: Split) extends Split with Serializable { - - def iterator() = rdd.iterator(split) + + def iterator(taskContext: TaskContext) = rdd.iterator(split, taskContext) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx } @@ -25,7 +22,7 @@ class UnionRDD[T: ClassManifest]( @transient rdds: Seq[RDD[T]]) extends RDD[T](sc) with Serializable { - + @transient val splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) @@ -44,13 +41,14 @@ class UnionRDD[T: ClassManifest]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) + deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } deps.toList } - - override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() + + override def compute(s: Split, taskContext: TaskContext): Iterator[T] = + s.asInstanceOf[UnionSplit[T]].iterator(taskContext) override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 80f0150c45..b987ca5fdf 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -1,21 +1,19 @@ package spark.rdd -import spark.Dependency 
-import spark.OneToOneDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext} + private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest]( - idx: Int, + idx: Int, rdd1: RDD[T], rdd2: RDD[U], split1: Split, split2: Split) extends Split with Serializable { - - def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2)) + + def iterator(taskContext: TaskContext): Iterator[(T, U)] = + rdd1.iterator(split1, taskContext).zip(rdd2.iterator(split2, taskContext)) def preferredLocations(): Seq[String] = rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2)) @@ -46,8 +44,9 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( @transient override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)) - - override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator() + + override def compute(s: Split, taskContext: TaskContext): Iterator[(T, U)] = + s.asInstanceOf[ZippedSplit[T, U]].iterator(taskContext) override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[ZippedSplit[T, U]].preferredLocations() diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 5c71207d43..29757b1178 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -16,8 +16,8 @@ import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId /** - * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for - * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal + * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for + * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). 
*/ @@ -73,7 +73,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; // that's not going to be a realistic assumption in general - + val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures @@ -94,7 +94,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { cacheLocs(rdd.id) } - + def updateCacheLocs() { cacheLocs = cacheTracker.getLocationsSnapshot() } @@ -326,7 +326,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val rdd = job.finalStage.rdd val split = rdd.splits(job.partitions(0)) val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) - val result = job.func(taskContext, rdd.iterator(split)) + val result = job.func(taskContext, rdd.iterator(split, taskContext)) + taskContext.executeOnCompleteCallbacks() job.listener.taskSucceeded(0, result) } catch { case e: Exception => @@ -353,7 +354,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } } - + def submitMissingTasks(stage: Stage) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry @@ -395,7 +396,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val task = event.task val stage = idToStage(task.stageId) event.reason match { - case Success => + case Success => logInfo("Completed " + task) if (event.accumUpdates != null) { Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted @@ -519,7 +520,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with updateCacheLocs() } } - + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. 
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 2ebd4075a2..e492279b4e 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -10,12 +10,14 @@ private[spark] class ResultTask[T, U]( @transient locs: Seq[String], val outputId: Int) extends Task[U](stageId) { - + val split = rdd.splits(partition) override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) - func(context, rdd.iterator(split)) + val result = func(context, rdd.iterator(split, context)) + context.executeOnCompleteCallbacks() + result } override def preferredLocations: Seq[String] = locs diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 60105c42b6..bd1911fce2 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -70,19 +70,19 @@ private[spark] object ShuffleMapTask { private[spark] class ShuffleMapTask( stageId: Int, - var rdd: RDD[_], + var rdd: RDD[_], var dep: ShuffleDependency[_,_], - var partition: Int, + var partition: Int, @transient var locs: Seq[String]) extends Task[MapStatus](stageId) with Externalizable with Logging { def this() = this(0, null, null, 0, null) - + var split = if (rdd == null) { - null - } else { + null + } else { rdd.splits(partition) } @@ -113,9 +113,11 @@ private[spark] class ShuffleMapTask( val numOutputSplits = dep.partitioner.numPartitions val partitioner = dep.partitioner + val taskContext = new TaskContext(stageId, partition, attemptId) + // Partition the map output. val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) - for (elem <- rdd.iterator(split)) { + for (elem <- rdd.iterator(split, taskContext)) { val pair = elem.asInstanceOf[(Any, Any)] val bucketId = partitioner.getPartition(pair._1) buckets(bucketId) += pair @@ -133,6 +135,9 @@ private[spark] class ShuffleMapTask( compressedSizes(i) = MapOutputTracker.compressSize(size) } + // Execute the callbacks on task completion. + taskContext.executeOnCompleteCallbacks() + return new MapStatus(blockManager.blockManagerId, compressedSizes) } -- cgit v1.2.3 From 4f076e105ee30edcb1941216c79d017c5175d9b8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Dec 2012 16:41:15 -0800 Subject: SPARK-635: Pass a TaskContext object to compute() interface and use that to close Hadoop input stream. Incorporated Matei's command. 
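
To illustrate the contract the two SPARK-635 patches introduce, the following is a minimal, self-contained Scala sketch of the completion-callback mechanism. The simplified TaskContext mirrors the API in the diffs (addOnCompleteCallback / executeOnCompleteCallbacks); the TaskContextDemo object and its StringReader are made-up stand-ins for what HadoopRDD.compute() and the task runners do, not part of the actual patches.

    import scala.collection.mutable.ArrayBuffer

    // Simplified stand-in for spark.TaskContext: callbacks registered while a
    // partition is being computed are run once the task finishes.
    class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) {
      private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

      def addOnCompleteCallback(f: () => Unit) {
        onCompleteCallbacks += f
      }

      def executeOnCompleteCallbacks() {
        onCompleteCallbacks.foreach(f => f())
      }
    }

    object TaskContextDemo {
      def main(args: Array[String]) {
        val context = new TaskContext(stageId = 0, splitId = 0, attemptId = 0L)
        val reader = new java.io.StringReader("some input")
        // Analogous to HadoopRDD/NewHadoopRDD in compute(): register cleanup
        // on the context instead of closing the reader inside the iterator.
        context.addOnCompleteCallback(() => reader.close())
        // ... the task body would consume the partition iterator here ...
        // Analogous to ResultTask/ShuffleMapTask after the task body returns:
        context.executeOnCompleteCallbacks()
      }
    }

One apparent benefit of registering the close() as a context callback, rather than closing the reader inside hasNext as NewHadoopRDD previously did, is that the input stream is released even when the iterator is not fully consumed.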
--- core/src/main/scala/spark/CacheTracker.scala | 5 ++--- core/src/main/scala/spark/RDD.scala | 8 ++++---- core/src/main/scala/spark/TaskContext.scala | 4 ++-- core/src/main/scala/spark/rdd/BlockRDD.scala | 2 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 6 +++--- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 4 ++-- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 4 ++-- core/src/main/scala/spark/rdd/FilteredRDD.scala | 3 +-- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 4 ++-- core/src/main/scala/spark/rdd/GlommedRDD.scala | 4 ++-- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 ++-- core/src/main/scala/spark/rdd/MapPartitionsRDD.scala | 3 +-- core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala | 4 ++-- core/src/main/scala/spark/rdd/MappedRDD.scala | 3 +-- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 11 ++++++----- core/src/main/scala/spark/rdd/PipedRDD.scala | 4 ++-- core/src/main/scala/spark/rdd/SampledRDD.scala | 6 +++--- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 2 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 6 +++--- core/src/main/scala/spark/rdd/ZippedRDD.scala | 8 ++++---- 20 files changed, 46 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index e9c545a2cf..3d79078733 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -167,8 +167,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b } // Gets or computes an RDD split - def getOrCompute[T]( - rdd: RDD[T], split: Split, taskContext: TaskContext, storageLevel: StorageLevel) + def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) : Iterator[T] = { val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) @@ -211,7 +210,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b // TODO: also register a listener for when it unloads logInfo("Computing partition " + split) val elements = new ArrayBuffer[Any] - elements ++= rdd.compute(split, taskContext) + elements ++= rdd.compute(split, context) try { // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index c53eab67e5..bb4c13c494 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -81,7 +81,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def splits: Array[Split] /** Function for computing a given partition. */ - def compute(split: Split, taskContext: TaskContext): Iterator[T] + def compute(split: Split, context: TaskContext): Iterator[T] /** How this RDD depends on any parent RDDs. */ @transient val dependencies: List[Dependency[_]] @@ -155,11 +155,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. 
    */
-  final def iterator(split: Split, taskContext: TaskContext): Iterator[T] = {
+  final def iterator(split: Split, context: TaskContext): Iterator[T] = {
     if (storageLevel != StorageLevel.NONE) {
-      SparkEnv.get.cacheTracker.getOrCompute[T](this, split, taskContext, storageLevel)
+      SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
     } else {
-      compute(split, taskContext)
+      compute(split, context)
     }
   }
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index b352db8167..d2746b26b3 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -6,11 +6,11 @@ import scala.collection.mutable.ArrayBuffer
 
 class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
 
   @transient
-  val onCompleteCallbacks = new ArrayBuffer[Unit => Unit]
+  val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 
   // Add a callback function to be executed on task completion. An example use
   // is for HadoopRDD to register a callback to close the input stream.
-  def registerOnCompleteCallback(f: Unit => Unit) {
+  def addOnCompleteCallback(f: () => Unit) {
     onCompleteCallbacks += f
   }
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 8209c36871..f98528a183 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -28,7 +28,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
 
   override def splits = splits_
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
+  override def compute(split: Split, context: TaskContext): Iterator[T] = {
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDSplit].blockId
     blockManager.get(blockId) match {
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 6bc0938ce2..4a7e5f3d06 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -36,10 +36,10 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
     rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
   }
 
-  override def compute(split: Split, taskContext: TaskContext) = {
+  override def compute(split: Split, context: TaskContext) = {
     val currSplit = split.asInstanceOf[CartesianSplit]
-    for (x <- rdd1.iterator(currSplit.s1, taskContext);
-         y <- rdd2.iterator(currSplit.s2, taskContext)) yield (x, y)
+    for (x <- rdd1.iterator(currSplit.s1, context);
+         y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
   }
 
   override val dependencies = List(
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 6037681cfd..de0d9fad88 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -68,7 +68,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
 
   override def preferredLocations(s: Split) = Nil
 
-  override def compute(s: Split, taskContext: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+  override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
     val split = s.asInstanceOf[CoGroupSplit]
     val numRdds = split.deps.size
     val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
@@ -78,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
       case NarrowCoGroupSplitDep(rdd, itsSplit) => {
         // Read them from the parent
-        for ((k, v) <- rdd.iterator(itsSplit, taskContext)) {
+        for ((k, v) <- rdd.iterator(itsSplit, context)) {
           getSeq(k.asInstanceOf[K])(depNum) += v
         }
       }
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 06ffc9c42c..1affe0e0ef 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -31,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
 
   override def splits = splits_
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
+  override def compute(split: Split, context: TaskContext): Iterator[T] = {
     split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
-      parentSplit => prev.iterator(parentSplit, taskContext)
+      parentSplit => prev.iterator(parentSplit, context)
     }
   }
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index 14a80d82c7..b148da28de 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -7,6 +7,5 @@ private[spark]
 class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    prev.iterator(split, taskContext).filter(f)
+  override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 64f8c51d6d..785662b2da 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -11,6 +11,6 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
 
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    prev.iterator(split, taskContext).flatMap(f)
+  override def compute(split: Split, context: TaskContext) =
+    prev.iterator(split, context).flatMap(f)
 }
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index d6b1b27d3e..fac8ffb4cb 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -7,6 +7,6 @@ private[spark]
 class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    Array(prev.iterator(split, taskContext).toArray).iterator
+  override def compute(split: Split, context: TaskContext) =
+    Array(prev.iterator(split, context).toArray).iterator
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index c6c035a096..ab163f569b 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -66,7 +66,7 @@ class HadoopRDD[K, V](
 
   override def splits = splits_
 
-  override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
+  override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[HadoopSplit]
     var reader: RecordReader[K, V] = null
 
@@ -75,7 +75,7 @@ class HadoopRDD[K, V](
     reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
 
     // Register an on-task-completion callback to close the input stream.
-    taskContext.registerOnCompleteCallback(Unit => reader.close())
+    context.addOnCompleteCallback(() => reader.close())
 
     val key: K = reader.createKey()
     val value: V = reader.createValue()
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index 715c240060..c764505345 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -14,6 +14,5 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
 
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    f(prev.iterator(split, taskContext))
+  override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 39f3c7b5f7..3d9888bd34 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -17,6 +17,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
   override val partitioner = if (preservesPartitioning) prev.partitioner else None
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    f(split.index, prev.iterator(split, taskContext))
+  override def compute(split: Split, context: TaskContext) =
+    f(split.index, prev.iterator(split, context))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index d82ab3f671..70fa8f4497 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -10,6 +10,5 @@ class MappedRDD[U: ClassManifest, T: ClassManifest](
 
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    prev.iterator(split, taskContext).map(f)
+  override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f)
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 61f4cbbe94..197ed5ea17 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -53,17 +53,18 @@ class NewHadoopRDD[K, V](
 
   override def splits = splits_
 
-  override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
+  override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[NewHadoopSplit]
     val conf = confBroadcast.value.value
     val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
-    val context = newTaskAttemptContext(conf, attemptId)
+    val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
     val format = inputFormatClass.newInstance
-    val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
-    reader.initialize(split.serializableHadoopSplit.value, context)
+    val reader = format.createRecordReader(
+      split.serializableHadoopSplit.value, hadoopAttemptContext)
+    reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
 
     // Register an on-task-completion callback to close the input stream.
-    taskContext.registerOnCompleteCallback(Unit => reader.close())
+    context.addOnCompleteCallback(() => reader.close())
 
     var havePair = false
     var finished = false
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index b34c7ea5b9..336e193217 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -29,7 +29,7 @@ class PipedRDD[T: ClassManifest](
 
   override val dependencies = List(new OneToOneDependency(parent))
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[String] = {
+  override def compute(split: Split, context: TaskContext): Iterator[String] = {
     val pb = new ProcessBuilder(command)
     // Add the environmental variables to the process.
     val currentEnvVars = pb.environment()
@@ -52,7 +52,7 @@ class PipedRDD[T: ClassManifest](
       override def run() {
         SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
-        for (elem <- parent.iterator(split, taskContext)) {
+        for (elem <- parent.iterator(split, context)) {
           out.println(elem)
         }
         out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 07a1487f3a..6e4797aabb 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -32,13 +32,13 @@ class SampledRDD[T: ClassManifest](
   override def preferredLocations(split: Split) =
     prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
 
-  override def compute(splitIn: Split, taskContext: TaskContext) = {
+  override def compute(splitIn: Split, context: TaskContext) = {
     val split = splitIn.asInstanceOf[SampledRDDSplit]
     if (withReplacement) {
       // For large datasets, the expected number of occurrences of each element in a sample with
       // replacement is Poisson(frac). We use that to get a count for each element.
       val poisson = new Poisson(frac, new DRand(split.seed))
-      prev.iterator(split.prev, taskContext).flatMap { element =>
+      prev.iterator(split.prev, context).flatMap { element =>
         val count = poisson.nextInt()
         if (count == 0) {
           Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
@@ -48,7 +48,7 @@ class SampledRDD[T: ClassManifest](
       }
     } else { // Sampling without replacement
       val rand = new Random(split.seed)
-      prev.iterator(split.prev, taskContext).filter(x => (rand.nextDouble <= frac))
+      prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
     }
   }
 }
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index c736e92117..f832633646 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -31,7 +31,7 @@ class ShuffledRDD[K, V](
   val dep = new ShuffleDependency(parent, part)
   override val dependencies = List(dep)
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[(K, V)] = {
+  override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
     SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
   }
 }
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 4b9cab8774..a08473f7be 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -12,7 +12,7 @@ private[spark] class UnionSplit[T: ClassManifest](
   extends Split
   with Serializable {
 
-  def iterator(taskContext: TaskContext) = rdd.iterator(split, taskContext)
+  def iterator(context: TaskContext) = rdd.iterator(split, context)
   def preferredLocations() = rdd.preferredLocations(split)
   override val index: Int = idx
 }
@@ -47,8 +47,8 @@ class UnionRDD[T: ClassManifest](
     deps.toList
   }
 
-  override def compute(s: Split, taskContext: TaskContext): Iterator[T] =
-    s.asInstanceOf[UnionSplit[T]].iterator(taskContext)
+  override def compute(s: Split, context: TaskContext): Iterator[T] =
+    s.asInstanceOf[UnionSplit[T]].iterator(context)
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index b987ca5fdf..92d667ff1e 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -12,8 +12,8 @@ private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
   extends Split
   with Serializable {
 
-  def iterator(taskContext: TaskContext): Iterator[(T, U)] =
-    rdd1.iterator(split1, taskContext).zip(rdd2.iterator(split2, taskContext))
+  def iterator(context: TaskContext): Iterator[(T, U)] =
+    rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
 
   def preferredLocations(): Seq[String] =
     rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
@@ -45,8 +45,8 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
   @transient
   override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
 
-  override def compute(s: Split, taskContext: TaskContext): Iterator[(T, U)] =
-    s.asInstanceOf[ZippedSplit[T, U]].iterator(taskContext)
+  override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
+    s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
-- 
cgit v1.2.3
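
Note on the TaskContext change above: the callback buffer moves from ArrayBuffer[Unit => Unit] (one-argument functions that happened to take the Unit value) to ArrayBuffer[() => Unit] (true no-argument thunks), which is why call sites now read context.addOnCompleteCallback(() => reader.close()). The short, self-contained Scala sketch below is written for this note and is not part of the patch; the object name and the println body are illustrative only, while the callback type and the addOnCompleteCallback shape mirror the diff.

    // Illustrative sketch only (not from the patch): registering and running
    // no-argument completion callbacks, the pattern the new
    // TaskContext.addOnCompleteCallback expects.
    import scala.collection.mutable.ArrayBuffer

    object CallbackSketch {
      val onCompleteCallbacks = new ArrayBuffer[() => Unit]

      // Same shape as the renamed TaskContext.addOnCompleteCallback.
      def addOnCompleteCallback(f: () => Unit) {
        onCompleteCallbacks += f
      }

      def main(args: Array[String]) {
        // A caller registers a thunk, e.g. to close an input stream when the task ends.
        addOnCompleteCallback(() => println("task done, releasing resources"))
        // The task runner invokes every registered thunk on completion.
        onCompleteCallbacks.foreach(f => f())
      }
    }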