// path: root/src/scala/spark/Accumulators.scala
// blob: ee93d3c85c0e9e318d8f0a1449a9fe81f4d6da07
package spark

import java.io._

import scala.collection.mutable.Map

/**
 * A serializable value of type `T` that is built up by repeatedly adding terms
 * with `+=`.
 *
 * The combining logic is delegated to the supplied [[AccumulatorParam]]. Every
 * instance registers itself with the `Accumulators` registry, both when it is
 * constructed and again when it is reconstructed through Java deserialization.
 *
 * @param initialValue starting value; only used during construction and marked
 *                     transient
 * @param param        supplies `addInPlace` and `zero` for `T`
 * @tparam T type of the accumulated value
 */
@serializable class Accumulator[T](
  @transient initialValue: T, param: AccumulatorParam[T])
{
  // Identifier handed out by the registry; must be assigned before registering.
  val id: Long = Accumulators.newId

  // Current value on master. Transient: after deserialization it is reset to
  // `zero` by readObject rather than restored from the stream.
  @transient var value_ = initialValue

  // Zero value to be passed to workers (derived from the initial value).
  val zero: T = param.zero(initialValue)

  // Flipped to true by readObject; used to reject `value = ...` on copies that
  // came out of a serialized stream.
  var deserialized: Boolean = false

  Accumulators.register(this)

  /** Folds `term` into the current value using `param.addInPlace`. */
  def +=(term: T): Unit = {
    value_ = param.addInPlace(value_, term)
  }

  /** Returns the current accumulated value. */
  def value: T = value_

  /**
   * Setter behind `acc.value = t`.
   *
   * @throws UnsupportedOperationException if this instance was produced by
   *                                       deserialization
   */
  def value_=(t: T): Unit = {
    if (deserialized)
      throw new UnsupportedOperationException("Can't use value_= in task")
    value_ = t
  }

  // Called by Java when deserializing an object: restore the non-transient
  // fields, start over from `zero`, mark the copy as deserialized and register
  // it with the registry.
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    value_ = zero
    deserialized = true
    Accumulators.register(this)
  }

  /** String form of the current value. */
  override def toString: String = value_.toString
}

/**
 * Strategy object telling an `Accumulator[T]` how to combine values of type `T`
 * and how to obtain a zero value for it.
 *
 * @tparam T type of the values being accumulated
 */
@serializable trait AccumulatorParam[T] {
  /**
   * Combines two values and returns the result. `Accumulator.+=` calls this
   * with the current value as `t1` and the new term as `t2`, and stores the
   * returned value as the new current value.
   * NOTE(review): the name suggests implementations may mutate and return
   * `t1` -- confirm against concrete implementations.
   */
  def addInPlace(t1: T, t2: T): T

  /**
   * Produces the zero value for an accumulator, given the accumulator's
   * initial value. `Accumulator` stores the result and resets its current
   * value to it when deserialized.
   */
  def zero(initialValue: T): T
}

// TODO: The multi-thread support in accumulators is crude (the registry is
// simply keyed by the registering thread); check whether there is a more
// intuitive way of doing it right.
/**
 * Registry of live accumulators, keyed by (registering thread, accumulator id).
 *
 * All methods synchronize on this object. Entries are scoped by thread:
 * `register`, `clear` and `values` always act on the calling thread's entries,
 * while `add` targets the entries of an explicitly supplied thread.
 */
private object Accumulators
{
  // TODO: Use soft references? => need to make readObject work properly then
  val accums = Map[(Thread, Long), Accumulator[_]]()
  var lastId: Long = 0

  /** Returns the next accumulator id; ids start at 1 and grow by one per call. */
  def newId: Long = synchronized {
    lastId += 1
    lastId
  }

  /**
   * Records `a` under the calling thread and `a.id`, replacing any entry that
   * already exists for that key.
   */
  def register(a: Accumulator[_]): Unit = synchronized {
    // Thread.currentThread() is used directly instead of Predef's
    // `currentThread` alias, which is not available in newer Scala versions.
    accums((Thread.currentThread(), a.id)) = a
  }

  /** Removes every accumulator that was registered by the calling thread. */
  def clear: Unit = synchronized {
    val self = Thread.currentThread()
    accums.retain((key, accum) => key._1 != self)
  }

  /**
   * Snapshots the accumulators registered by the calling thread.
   *
   * @return a fresh mutable map from accumulator id to its current value
   */
  def values: Map[Long, Any] = synchronized {
    val self = Thread.currentThread()
    val ret = Map[Long, Any]()
    for (((thread, id), accum) <- accums if thread == self)
      ret(id) = accum.value
    ret
  }

  /**
   * Adds each supplied value into the accumulator registered under
   * (`thread`, id). Ids with no matching registration are silently ignored.
   *
   * @param thread thread whose registered accumulators should be updated
   * @param values accumulator id -> term to add
   */
  def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
    for ((id, value) <- values) {
      // Single lookup instead of contains + apply.
      accums.get((thread, id)).foreach { accum =>
        // The registry stores Accumulator[_]; the element type is not known
        // here, so the cast to Accumulator[Any] is unchecked.
        accum.asInstanceOf[Accumulator[Any]] += value
      }
    }
  }
}