aboutsummaryrefslogtreecommitdiff
path: root/src/scala/spark/Broadcast.scala
blob: f59912b4c86e08d38994de2d86d609c893e5d693 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package spark

import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}

/**
 * Handle for a broadcast variable.
 *
 * Every instance receives a random identifier when it is constructed; the
 * identifier is what `toString` renders.
 */
@serializable
trait Broadcast[T] {
  /** Random identifier generated when this instance is created. */
  val uuid: UUID = UUID.randomUUID

  /** The wrapped value. */
  def value: T

  // An abstract `readObject` is intentionally not declared here: Java
  // serialization requires `readObject` to be private in each sub-class, which
  // conflicts with declaring it abstract in this trait (possibly a Scala bug).

  /** Implementation-specific broadcast action. */
  def sendBroadcast: Unit

  override def toString: String = "spark.Broadcast(" + uuid.toString + ")"
}

/**
 * Pluggable creator of [[Broadcast]] handles.
 *
 * `Broadcast.initialize` loads the implementing class by name and creates it
 * with `newInstance`, so implementations need an accessible no-argument
 * constructor.
 */
trait BroadcastFactory {
  /**
   * One-time setup, invoked by `Broadcast.initialize` right after construction.
   *
   * @param isMaster the flag handed to `Broadcast.initialize`
   */
  def initialize(isMaster: Boolean): Unit

  /**
   * Wraps `value_` in a new broadcast handle.
   *
   * @param value_  the value to broadcast
   * @param isLocal implementation-defined flag (presumably local mode — TODO
   *                confirm against the implementations)
   */
  def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
}

/**
 * Process-wide entry point for the broadcast subsystem.
 *
 * Holds the single [[BroadcastFactory]] chosen by the `spark.broadcast.Factory`
 * system property (default `spark.BitTorrentBroadcastFactory`) and offers
 * helpers that build thread pools whose threads are daemons.
 */
private object Broadcast
extends Logging {
  // Only read and written inside `initialize`, under this object's monitor.
  private var initialized = false
  // Assigned once, inside `initialize`, and only after the factory's own
  // `initialize` has returned. @volatile so `getBroadcastFactory`, which does
  // not take the lock, sees the fully initialized factory from any thread.
  @volatile private var broadcastFactory: BroadcastFactory = null

  /**
   * Loads, instantiates and initializes the configured BroadcastFactory.
   * Called by SparkContext or Executor before using Broadcast.
   *
   * Only the first successful call has an effect. If loading, instantiating or
   * initializing the factory throws, the exception propagates, nothing is
   * published, and a later call will try again.
   *
   * @param isMaster forwarded to `BroadcastFactory.initialize`
   */
  def initialize (isMaster: Boolean): Unit = synchronized {
    if (!initialized) {
      val broadcastFactoryClass = System.getProperty("spark.broadcast.Factory",
        "spark.BitTorrentBroadcastFactory")
      // Requires an accessible no-argument constructor.
      val factory =
        Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

      // Let the factory finish its own setup before anyone can obtain it.
      factory.initialize(isMaster)

      broadcastFactory = factory
      initialized = true
    }
  }

  /**
   * Returns the factory installed by `initialize`.
   *
   * @throws SparkException if `initialize` has not completed successfully
   */
  def getBroadcastFactory: BroadcastFactory = {
    // Single volatile read so the null check and the result agree.
    val factory = broadcastFactory
    if (factory == null) {
      throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
    }
    factory
  }

  /**
   * Returns a ThreadFactory that delegates to the JDK default factory but
   * marks every thread as a daemon.
   */
  private def newDaemonThreadFactory: ThreadFactory = {
    // One default factory per daemon factory, so threads created for the same
    // pool are numbered sequentially instead of each coming from a fresh
    // default factory.
    val defaultFactory = Executors.defaultThreadFactory
    new ThreadFactory {
      def newThread(r: Runnable): Thread = {
        val t = defaultFactory.newThread(r)
        t.setDaemon(true)
        t
      }
    }
  }

  /** Wrapper over `Executors.newCachedThreadPool` using daemon threads. */
  def newDaemonCachedThreadPool: ThreadPoolExecutor =
    Executors.newCachedThreadPool(newDaemonThreadFactory).asInstanceOf[ThreadPoolExecutor]

  /**
   * Wrapper over `Executors.newFixedThreadPool` using daemon threads.
   *
   * @param nThreads number of threads in the pool
   */
  def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor =
    Executors.newFixedThreadPool(nThreads, newDaemonThreadFactory)
      .asInstanceOf[ThreadPoolExecutor]
}

/**
 * Address and size information for a broadcast source (host, listen port,
 * total block and byte counts), plus mutable bookkeeping about it.
 *
 * The `var`s below are not constructor parameters, so they take no part in the
 * case-class `equals`/`hashCode`. Instances therefore stay stable as map keys
 * while these counters change.
 */
@serializable
case class SourceInfo (hostAddress: String, listenPort: Int,
  totalBlocks: Int, totalBytes: Int)
extends Comparable[SourceInfo] with Logging {

  var currentLeechers = 0
  var receptionFailed = false

  var hasBlocks = 0
  var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)

  /**
   * Ascending order by `currentLeechers`.
   *
   * Uses explicit comparisons rather than subtraction, so the result can never
   * overflow and flip sign.
   * NOTE(review): this ordering ignores the constructor fields and is therefore
   * inconsistent with `equals`.
   */
  def compareTo (o: SourceInfo): Int =
    if (currentLeechers < o.currentLeechers) -1
    else if (currentLeechers > o.currentLeechers) 1
    else 0
}

/** Sentinel constants associated with [[SourceInfo]]. */
object SourceInfo {
  // Special values that may appear in `SourceInfo.listenPort` instead of a
  // real port number.
  val TxNotStartedRetry: Int = -1
  val TxOverGoToHDFS: Int = 0

  // Further sentinel constants. Note that UnusedParam has the same value (0)
  // as TxOverGoToHDFS, so the two cannot be told apart by value alone.
  val StopBroadcast: Int = -2
  val UnusedParam: Int = 0
}

/**
 * One block of a broadcast payload: its index and its raw bytes.
 *
 * Case-class equality on `byteArray` is array reference equality, not a
 * content comparison.
 */
@serializable
case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])

/**
 * Block-level description of a broadcast variable: its blocks together with
 * the total block and byte counts.
 *
 * `arrayOfBlocks` and `hasBlocks` are @transient, so Java serialization skips
 * them. After deserialization `arrayOfBlocks` is null and `hasBlocks` is 0;
 * only the totals travel.
 */
@serializable
case class VariableInfo(
    @transient val arrayOfBlocks: Array[BroadcastBlock],
    totalBlocks: Int,
    totalBytes: Int) {
  @transient var hasBlocks = 0
}

/**
 * Accumulates observed transfer times per source.
 *
 * For every [[SourceInfo]] it keeps the running total of reported milliseconds
 * and the number of data points. `getTimePerBlock` derives the average from
 * these. All access to the map goes through this tracker's own monitor.
 */
@serializable
class SpeedTracker {
  // Mapping 'source' to '(totalTime, numBlocks)'.
  // The map is immutable and this var is reassigned on every update, so the
  // map itself must not serve as the lock: each update would swap the monitor
  // and concurrent writers could lose data points. We lock on the tracker
  // instead. A dedicated lock object is avoided because java.lang.Object is
  // not Serializable.
  private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()

  /**
   * Records one block received from `srcInfo` that took `timeInMillis` ms.
   *
   * @param srcInfo      the source the block came from
   * @param timeInMillis time spent on that block, in milliseconds
   */
  def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = synchronized {
    val (totalTime, numBlocks) = sourceToSpeedMap.getOrElse(srcInfo, (0L, 0))
    val updated = (totalTime + timeInMillis, numBlocks + 1)
    sourceToSpeedMap += (srcInfo -> updated)
  }

  /**
   * Average milliseconds per recorded block for `srcInfo`, including the
   * fractional part.
   *
   * @throws NoSuchElementException if no data point was recorded for `srcInfo`
   */
  def getTimePerBlock (srcInfo: SourceInfo): Double = synchronized {
    val (totalTime, numBlocks) = sourceToSpeedMap (srcInfo)
    // Widen before dividing; Long / Int would truncate the average.
    totalTime.toDouble / numBlocks
  }

  override def toString = synchronized { sourceToSpeedMap.toString }
}