aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-01-25 15:33:26 -0800
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-01-25 15:33:26 -0800
commit2435b7b5b77fe3e30fc8175983460ceb7a70632e (patch)
treeb9de4688e11e588dd7c4f4a048544893e185daa1
parent04bfee2d08a566c06ebb6278b56556e914814497 (diff)
parent8efbda0b179e3821a1221c6d78681fc74248cdac (diff)
downloadspark-2435b7b5b77fe3e30fc8175983460ceb7a70632e.tar.gz
spark-2435b7b5b77fe3e30fc8175983460ceb7a70632e.tar.bz2
spark-2435b7b5b77fe3e30fc8175983460ceb7a70632e.zip
Merge pull request #416 from stephenh/morefinally
Call executeOnCompleteCallbacks in more finally blocks.
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala13
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala46
2 files changed, 30 insertions, 29 deletions
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index b320be8863..f599eb00bd 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -40,7 +40,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
eventQueue.put(HostLost(host))
}
- // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures.
+ // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -54,8 +54,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// resubmit failed stages
val POLL_TIMEOUT = 10L
- private val lock = new Object // Used for access to the entire DAGScheduler
-
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
val nextRunId = new AtomicInteger(0)
@@ -337,9 +335,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- val result = job.func(taskContext, rdd.iterator(split, taskContext))
- taskContext.executeOnCompleteCallbacks()
- job.listener.taskSucceeded(0, result)
+ try {
+ val result = job.func(taskContext, rdd.iterator(split, taskContext))
+ job.listener.taskSucceeded(0, result)
+ } finally {
+ taskContext.executeOnCompleteCallbacks()
+ }
} catch {
case e: Exception =>
job.listener.jobFailed(e)
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 19f5328eee..83641a2a84 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask(
with Externalizable
with Logging {
- def this() = this(0, null, null, 0, null)
+ protected def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
null
@@ -117,34 +117,34 @@ private[spark] class ShuffleMapTask(
override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
- val partitioner = dep.partitioner
val taskContext = new TaskContext(stageId, partition, attemptId)
+ try {
+ // Partition the map output.
+ val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ for (elem <- rdd.iterator(split, taskContext)) {
+ val pair = elem.asInstanceOf[(Any, Any)]
+ val bucketId = dep.partitioner.getPartition(pair._1)
+ buckets(bucketId) += pair
+ }
+ val bucketIterators = buckets.map(_.iterator)
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
- for (elem <- rdd.iterator(split, taskContext)) {
- val pair = elem.asInstanceOf[(Any, Any)]
- val bucketId = partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
- }
- val bucketIterators = buckets.map(_.iterator)
+ val compressedSizes = new Array[Byte](numOutputSplits)
- val compressedSizes = new Array[Byte](numOutputSplits)
+ val blockManager = SparkEnv.get.blockManager
+ for (i <- 0 until numOutputSplits) {
+ val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
+ // Get a Scala iterator from Java map
+ val iter: Iterator[(Any, Any)] = bucketIterators(i)
+ val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ compressedSizes(i) = MapOutputTracker.compressSize(size)
+ }
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = bucketIterators(i)
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
- compressedSizes(i) = MapOutputTracker.compressSize(size)
+ return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ } finally {
+ // Execute the callbacks on task completion.
+ taskContext.executeOnCompleteCallbacks()
}
-
- // Execute the callbacks on task completion.
- taskContext.executeOnCompleteCallbacks()
-
- return new MapStatus(blockManager.blockManagerId, compressedSizes)
}
override def preferredLocations: Seq[String] = locs