diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-07-14 14:08:34 -0400 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-07-14 14:08:34 -0400 |
commit | 797b4547c3ee0cad522b733eeb65cfacbef2f08c (patch) | |
tree | 631196490070a632d62f2c0492c2fc38b729e18b /core | |
parent | 0ccfe20755665aa4c347b82e18297c5b3a2284ee (diff) | |
download | spark-797b4547c3ee0cad522b733eeb65cfacbef2f08c.tar.gz spark-797b4547c3ee0cad522b733eeb65cfacbef2f08c.tar.bz2 spark-797b4547c3ee0cad522b733eeb65cfacbef2f08c.zip |
Fix tracking of updates in accumulators to solve an issue that would manifest in the 2.9 interpreter
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/spark/Accumulators.scala | 32 | ||||
-rw-r--r-- | core/src/main/scala/spark/DAGScheduler.scala | 2 |
2 files changed, 21 insertions, 13 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index ee93d3c85c..4f51826d9d 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -12,7 +12,7 @@ import scala.collection.mutable.Map val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false - Accumulators.register(this) + Accumulators.register(this, true) def += (term: T) { value_ = param.addInPlace(value_, term) } def value = this.value_ @@ -26,7 +26,7 @@ import scala.collection.mutable.Map in.defaultReadObject value_ = zero deserialized = true - Accumulators.register(this) + Accumulators.register(this, false) } override def toString = value_.toString @@ -42,31 +42,39 @@ import scala.collection.mutable.Map private object Accumulators { // TODO: Use soft references? => need to make readObject work properly then - val accums = Map[(Thread, Long), Accumulator[_]]() - var lastId: Long = 0 + val originals = Map[Long, Accumulator[_]]() + val localAccums = Map[Thread, Map[Long, Accumulator[_]]]() + var lastId: Long = 0 def newId: Long = synchronized { lastId += 1; return lastId } - def register(a: Accumulator[_]): Unit = synchronized { - accums((currentThread, a.id)) = a + def register(a: Accumulator[_], original: Boolean): Unit = synchronized { + if (original) { + originals(a.id) = a + } else { + val accums = localAccums.getOrElseUpdate(currentThread, Map()) + accums(a.id) = a + } } + // Clear the local (non-original) accumulators for the current thread def clear: Unit = synchronized { - accums.retain((key, accum) => key._1 != currentThread) + localAccums.remove(currentThread) } + // Get the values of the local accumulators for the current thread (by ID) def values: Map[Long, Any] = synchronized { val ret = Map[Long, Any]() - for(((thread, id), accum) <- accums if thread == currentThread) + for ((id, accum) <- localAccums.getOrElse(currentThread, Map())) ret(id) = accum.value return ret } - def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized { + // Add values to the original accumulators with some given IDs + def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { - if (accums.contains((thread, id))) { - val accum = accums((thread, id)) - accum.asInstanceOf[Accumulator[Any]] += value + if (originals.contains(id)) { + originals(id).asInstanceOf[Accumulator[Any]] += value } } } diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala index 42bb3c2a75..93cab9fb62 100644 --- a/core/src/main/scala/spark/DAGScheduler.scala +++ b/core/src/main/scala/spark/DAGScheduler.scala @@ -225,7 +225,7 @@ private trait DAGScheduler extends Scheduler with Logging { if (evt.reason == Success) { // A task ended logInfo("Completed " + evt.task) - Accumulators.add(currentThread, evt.accumUpdates) + Accumulators.add(evt.accumUpdates) evt.task match { case rt: ResultTask[_, _] => results(rt.outputId) = evt.result.asInstanceOf[U] |