core/src/main/scala/spark/Accumulators.scala

package spark

import java.io._

import scala.collection.mutable.Map
import scala.collection.generic.Growable

/**
 * A datatype that can be accumulated, i.e. has a commutative and associative "add" operation,
 * but where the result type, `R`, may be different from the element type being added, `T`.
 *
 * You must define how to add data, and how to merge two of these together.  For some datatypes,
 * such as a counter, these might be the same operation. In that case, you can use the simpler
 * [[spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are
 * accumulating a set. You will add items to the set, and you will union two sets together.
 *
 * @param initialValue initial value of accumulator
 * @param param helper object defining how to add elements of type `R` and `T`
 * @tparam R the full accumulated data (result type)
 * @tparam T partial data that can be added in
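 *
 * A minimal usage sketch (assumes a `SparkContext` named `sc` and an implicit
 * `AccumulableParam[Set[String], String]` in scope; both are illustrative here):
 * {{{
 *   val seen = sc.accumulable(Set.empty[String])                 // R = Set[String], T = String
 *   sc.parallelize(Seq("a", "b", "a")).foreach(x => seen += x)   // add T values in tasks
 *   println(seen.value)                                          // merged R, master only
 * }}}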
 */
class Accumulable[R, T] (
    @transient initialValue: R,
    param: AccumulableParam[R, T])
  extends Serializable {
  
  val id = Accumulators.newId
  @transient private var value_ = initialValue // Current value on master
  val zero = param.zero(initialValue)  // Zero value to be passed to workers
  var deserialized = false

  Accumulators.register(this, true)

  /**
   * Add more data to this accumulator / accumulable
   * @param term the data to add
   */
  def += (term: T) { value_ = param.addAccumulator(value_, term) }

  /**
   * Add more data to this accumulator / accumulable
   * @param term the data to add
   */
  def add(term: T) { value_ = param.addAccumulator(value_, term) }

  /**
   * Merge two accumulable objects together
   *
   * Normally, a user will not want to use this version, but will instead call `+=`.
   * @param term the other `R` that will get merged with this
   */
  def ++= (term: R) { value_ = param.addInPlace(value_, term) }

  /**
   * Merge two accumulable objects together
   *
   * Normally, a user will not want to use this version, but will instead call `add`.
   * @param term the other `R` that will get merged with this
   */
  def merge(term: R) { value_ = param.addInPlace(value_, term) }

  /**
   * Access the accumulator's current value; only allowed on master.
   */
  def value: R = {
    if (!deserialized) {
      value_
    } else {
      throw new UnsupportedOperationException("Can't read accumulator value in task")
    }
  }

  /**
   * Get the current value of this accumulator from within a task.
   *
   * This is NOT the global value of the accumulator.  To get the global value after a
   * completed operation on the dataset, call `value`.
   *
   * The typical use of this method is to directly mutate the local value, e.g., to add
   * an element to a Set.
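   *
   * For instance (a sketch; `accum` is an illustrative name):
   * {{{
   *   // inside a task, where `accum` is an Accumulable[mutable.Set[String], String]
   *   accum.localValue += "element"   // mutate the task-local set directly
   * }}}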
   */
  def localValue = value_

  /**
   * Set the accumulator's value; only allowed on master.
   */
  def value_= (newValue: R) {
    if (!deserialized) value_ = newValue
    else throw new UnsupportedOperationException("Can't assign accumulator value in task")
  }

  /**
   * Set the accumulator's value; only allowed on master.
   */
  def setValue(newValue: R) {
    this.value = newValue
  }
 
  // Called by Java when deserializing this object. Reset the value to zero so that each
  // task accumulates only its own updates; these are later merged into the original.
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject()
    value_ = zero
    deserialized = true
    Accumulators.register(this, false)
  }

  override def toString = value_.toString
}

/**
 * Helper object defining how to accumulate values of a particular type. An implicit
 * AccumulableParam needs to be available when you create Accumulables of a specific type.
 *
 * @tparam R the full accumulated data (result type)
 * @tparam T partial data that can be added in
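 *
 * A sketch of an implementation that accumulates strings into an immutable Set
 * (the object name is illustrative, not part of Spark):
 * {{{
 *   implicit object StringSetParam extends AccumulableParam[Set[String], String] {
 *     def addAccumulator(r: Set[String], t: String): Set[String] = r + t
 *     def addInPlace(r1: Set[String], r2: Set[String]): Set[String] = r1 ++ r2
 *     def zero(initialValue: Set[String]): Set[String] = Set.empty[String]
 *   }
 * }}}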
 */
trait AccumulableParam[R, T] extends Serializable {
  /**
   * Add additional data to the accumulator value. Is allowed to modify and return `r`
   * for efficiency (to avoid allocating objects).
   *
   * @param r the current value of the accumulator
   * @param t the data to be added to the accumulator
   * @return the new value of the accumulator
   */
  def addAccumulator(r: R, t: T): R

  /**
   * Merge two accumulated values together. Is allowed to modify and return the first value
   * for efficiency (to avoid allocating objects).
   *
   * @param r1 one set of accumulated data
   * @param r2 another set of accumulated data
   * @return both data sets merged together
   */
  def addInPlace(r1: R, r2: R): R

  /**
   * Return the "zero" (identity) value for an accumulator type, given its initial value. For
   * example, if R was a vector of N dimensions, this would return a vector of N zeroes.
   */
  def zero(initialValue: R): R
}

/**
 * An [[spark.AccumulableParam]] for any growable collection (e.g. a mutable Set or Buffer):
 * elements are added with `+=`, whole collections are merged with `++=`, and the zero value
 * is an emptied serialized copy of the initial collection.
 */
private[spark]
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
  extends AccumulableParam[R,T] {

  def addAccumulator(growable: R, elem: T): R = {
    growable += elem
    growable
  }

  def addInPlace(t1: R, t2: R): R = {
    t1 ++= t2
    t1
  }

  def zero(initialValue: R): R = {
    // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
    // Instead we'll serialize it to a buffer and load it back.
    val ser = (new spark.JavaSerializer).newInstance()
    val copy = ser.deserialize[R](ser.serialize(initialValue))
    copy.clear()   // In case it contained stuff
    copy
  }
}

/**
 * A simpler version of [[spark.Accumulable]] where the result type being accumulated is
 * the same as the type of the elements being merged.
 *
 * @param initialValue initial value of accumulator
 * @param param helper object defining how to add elements of type `T`
 * @tparam T result type
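 *
 * A minimal usage sketch (assumes a `SparkContext` named `sc` with the built-in
 * implicit `AccumulatorParam[Int]` in scope):
 * {{{
 *   val sum = sc.accumulator(0)
 *   sc.parallelize(1 to 100).foreach(x => sum += x)   // updated in tasks
 *   println(sum.value)                                // 5050, readable on the master only
 * }}}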
 */
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T])
  extends Accumulable[T,T](initialValue, param)

/**
 * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in
 * is the same type as the accumulated value. An implicit AccumulatorParam object needs to
 * be available when you create Accumulators of a specific type.
 *
 * @tparam T type of value to accumulate
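 *
 * A sketch of a custom implementation that keeps a running maximum (the object
 * name is illustrative, not part of Spark):
 * {{{
 *   implicit object MaxAccumulatorParam extends AccumulatorParam[Int] {
 *     def addInPlace(t1: Int, t2: Int): Int = math.max(t1, t2)
 *     def zero(initialValue: Int): Int = Int.MinValue   // identity element for max
 *   }
 * }}}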
 */
trait AccumulatorParam[T] extends AccumulableParam[T, T] {
  def addAccumulator(t1: T, t2: T): T = {
    addInPlace(t1, t2)
  }
}

// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
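// Lifecycle, as pieced together from the methods below: the master registers each
// accumulator as an "original"; when a task deserializes one, readObject resets it to
// the zero value and registers it as a thread-local copy; once the task finishes, the
// scheduler collects the local values with values() and merges them into the originals
// with add().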
private object Accumulators {
  // TODO: Use soft references? => need to make readObject work properly then
  val originals = Map[Long, Accumulable[_, _]]()
  val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
  var lastId: Long = 0
  
  def newId: Long = synchronized {
    lastId += 1
    lastId
  }

  def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
    if (original) {
      originals(a.id) = a
    } else {
      val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
      accums(a.id) = a
    }
  }

  // Clear the local (non-original) accumulators for the current thread
  def clear() {
    synchronized {
      localAccums.remove(Thread.currentThread)
    }
  }

  // Get the values of the local accumulators for the current thread (by ID)
  def values: Map[Long, Any] = synchronized {
    val ret = Map[Long, Any]()
    for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
      ret(id) = accum.localValue
    }
    ret
  }

  // Add values to the original accumulators with some given IDs
  def add(values: Map[Long, Any]): Unit = synchronized {
    for ((id, value) <- values) {
      if (originals.contains(id)) {
        originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
      }
    }
  }
}