aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/Accumulators.scala
blob: bf774178526142976da87b035bd9574b0fbca2c5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package spark

import java.io._

import scala.collection.mutable.Map

class Accumulable[T,R] (
    @transient initialValue: T,
    param: AccumulableParam[T,R])
  extends Serializable {

  /** Unique ID used to look this accumulable up in the Accumulators registry. */
  val id = Accumulators.newId
  // Current value; meaningful only on the master. Reset to `zero` when
  // deserialized inside a task (see readObject below).
  @transient
  private var currentValue = initialValue
  /** Zero value shipped to workers so each task accumulates from a clean slate. */
  val zero = param.zero(initialValue)
  // Set to true once this copy has been deserialized inside a task.
  var deserialized = false

  Accumulators.register(this, true)

  /**
   * add more data to this accumulator / accumulable
   * @param term the data to add
   */
  def += (term: R) {
    currentValue = param.addAccumulator(currentValue, term)
  }

  /**
   * merge two accumulable objects together
   *
   * Normally, a user will not want to use this version, but will instead call `+=`.
   * @param term the other Accumulable that will get merged with this
   */
  def ++= (term: T) {
    currentValue = param.addInPlace(currentValue, term)
  }

  /** Current value on the master; reading from within a task is an error. */
  def value = {
    if (deserialized) {
      throw new UnsupportedOperationException("Can't use read value in task")
    }
    currentValue
  }

  // Raw value with no master-only check; used when collecting partial results
  // from a task thread.
  private[spark] def localValue = currentValue

  /** Set the value on the master; assigning from within a task is an error. */
  def value_= (t: T) {
    if (deserialized) {
      throw new UnsupportedOperationException("Can't use value_= in task")
    }
    currentValue = t
  }

  // Called by Java serialization when deserializing inside a task: start from
  // `zero` and register as a task-local (non-original) accumulator.
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject
    currentValue = zero
    deserialized = true
    Accumulators.register(this, false)
  }

  override def toString = currentValue.toString
}

/**
 * A simpler [[spark.Accumulable]] where the type added in (`+=`) is the same
 * as the accumulated type, using an [[spark.AccumulatorParam]] to define the
 * zero value and the addition operation.
 */
class Accumulator[T](
    @transient initialValue: T,
    param: AccumulatorParam[T]) extends Accumulable[T,T](initialValue, param)

/**
 * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type
 * as the accumulated value
 * @tparam T
 */
trait AccumulatorParam[T] extends AccumulableParam[T,T] {
  // Adding a single element and merging two accumulated values are the same
  // operation here, so delegate to addInPlace.
  def addAccumulator(t1: T, t2: T): T = addInPlace(t1, t2)
}

/**
 * A datatype that can be accumulated, i.e. has a commutative &amp; associative +.
 *
 * You must define how to add data, and how to merge two of these together.  For some datatypes, these might be
 * the same operation (eg., a counter).  In that case, you might want to use [[spark.AccumulatorParam]].  They won't
 * always be the same, though -- eg., imagine you are accumulating a set.  You will add items to the set, and you
 * will union two sets together.
 *
 * @tparam R the full accumulated data
 * @tparam T partial data that can be added in
 */
trait AccumulableParam[R,T] extends Serializable {
  /**
   * Add additional data to the accumulator value.
   * @param t1 the current value of the accumulator
   * @param t2 the data to be added to the accumulator
   * @return the new value of the accumulator
   */
  def addAccumulator(t1: R, t2: T) : R

  /**
   * merge two accumulated values together
   * @param t1 one set of accumulated data
   * @param t2 another set of accumulated data
   * @return both data sets merged together
   */
  def addInPlace(t1: R, t2: R): R

  /**
   * Return the "zero" (identity) value for the accumulator type, given its
   * initial value. For example, for a list accumulator this would be Nil.
   */
  def zero(initialValue: R): R
}

// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private object Accumulators {
  // TODO: Use soft references? => need to make readObject work properly then
  // Master-side originals, keyed by accumulator ID.
  val originals = Map[Long, Accumulable[_,_]]()
  // Task-side (deserialized) copies, grouped by the thread running the task.
  val localAccums = Map[Thread, Map[Long, Accumulable[_,_]]]()
  var lastId: Long = 0

  /** Allocate a new unique accumulator ID. */
  def newId: Long = synchronized {
    lastId += 1
    lastId
  }

  /**
   * Register an accumulable in this registry.
   * @param original true for the master-side original, false for a
   *                 task-local (deserialized) copy
   */
  def register(a: Accumulable[_,_], original: Boolean): Unit = synchronized {
    if (original) {
      originals(a.id) = a
    } else {
      val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
      accums(a.id) = a
    }
  }

  // Clear the local (non-original) accumulators for the current thread
  def clear: Unit = synchronized {
    localAccums.remove(Thread.currentThread)
  }

  // Get the values of the local accumulators for the current thread (by ID)
  def values: Map[Long, Any] = synchronized {
    val ret = Map[Long, Any]()
    for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
      ret(id) = accum.localValue
    }
    ret
  }

  // Add values to the original accumulators with some given IDs
  def add(values: Map[Long, Any]): Unit = synchronized {
    for ((id, value) <- values) {
      // Single lookup via get instead of contains + apply; IDs with no
      // registered original are silently skipped, as before.
      for (accum <- originals.get(id)) {
        accum.asInstanceOf[Accumulable[Any, Any]] ++= value
      }
    }
  }
}