diff options
-rw-r--r-- | src/scala/spark/MesosScheduler.scala | 32 |
1 files changed, 20 insertions, 12 deletions
diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala index 081f720bbd..84b9d9af68 100644 --- a/src/scala/spark/MesosScheduler.scala +++ b/src/scala/spark/MesosScheduler.scala @@ -260,23 +260,31 @@ extends ParallelOperation def taskFinished(status: TaskStatus) { val tid = status.getTaskId println("Finished TID " + tid) - // Deserialize task result - val result = Utils.deserialize[TaskResult[T]](status.getData) - results(tidToIndex(tid)) = result.value - // Update accumulators - Accumulators.add(callingThread, result.accumUpdates) - // Mark finished and stop if we've finished all the tasks - finished(tidToIndex(tid)) = true - tasksFinished += 1 - if (tasksFinished == numTasks) - setAllFinished() + if (!finished(tidToIndex(tid))) { + // Deserialize task result + val result = Utils.deserialize[TaskResult[T]](status.getData) + results(tidToIndex(tid)) = result.value + // Update accumulators + Accumulators.add(callingThread, result.accumUpdates) + // Mark finished and stop if we've finished all the tasks + finished(tidToIndex(tid)) = true + tasksFinished += 1 + if (tasksFinished == numTasks) + setAllFinished() + } else { + printf("Task %s had already finished, so ignoring it\n", tidToIndex(tid)) + } } def taskLost(status: TaskStatus) { val tid = status.getTaskId println("Lost TID " + tid) - launched(tidToIndex(tid)) = false - tasksLaunched -= 1 + if (!finished(tid)) { + launched(tidToIndex(tid)) = false + tasksLaunched -= 1 + } else { + printf("Task %s had already finished, so ignoring it\n", tidToIndex(tid)) + } } def error(code: Int, message: String) { |