path: root/core/src/main/scala/spark/broadcast/Broadcast.scala
package spark.broadcast

import java.io._
import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

import spark._

@serializable
trait Broadcast[T] {
  val uuid = UUID.randomUUID

  def value: T

  // We cannot have an abstract readObject here due to some weird issues with
  // readObject having to be 'private' in sub-classes. Possibly a Scala bug!

  override def toString = "spark.Broadcast(" + uuid + ")"
}

object Broadcast extends Logging {
  // Messages
  val REGISTER_BROADCAST_TRACKER = 0
  val UNREGISTER_BROADCAST_TRACKER = 1
  val FIND_BROADCAST_TRACKER = 2
  val GET_UPDATED_SHARE = 3

  private var initialized = false
  private var isMaster_ = false
  private var broadcastFactory: BroadcastFactory = null

  // Called by SparkContext or Executor before using Broadcast
  def initialize (isMaster__ : Boolean): Unit = synchronized {
    if (!initialized) {
      val broadcastFactoryClass = System.getProperty(
        "spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")

      broadcastFactory =
        Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

      // Setup isMaster before using it
      isMaster_ = isMaster__
      
      // Set masterHostAddress to the master's IP address for the slaves to read
      if (isMaster) {
        System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress)
      }

      // Initialize appropriate BroadcastFactory and BroadcastObject
      broadcastFactory.initialize(isMaster)

      initialized = true
    }
  }
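
  // Usage sketch (not from the original code, just illustrating the flow above):
  // the factory is chosen via the "spark.broadcast.factory" system property, so
  // a driver wanting a different BroadcastFactory sets the property before the
  // first call to initialize. The class name below is simply the default named
  // above; any BroadcastFactory implementation on the classpath would do.
  //
  //   System.setProperty("spark.broadcast.factory",
  //     "spark.broadcast.DfsBroadcastFactory")
  //   Broadcast.initialize(isMaster__ = true)    // on the master
  //   Broadcast.initialize(isMaster__ = false)   // on each slave (via Executor)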

  def getBroadcastFactory: BroadcastFactory = {
    if (broadcastFactory == null) {
      throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
    }
    broadcastFactory
  }

  // Load common broadcast-related config parameters
  private var MasterHostAddress_ = System.getProperty(
    "spark.broadcast.masterHostAddress", "")
  private var MasterTrackerPort_ = System.getProperty(
    "spark.broadcast.masterTrackerPort", "11111").toInt
  private var BlockSize_ = System.getProperty(
    "spark.broadcast.blockSize", "4096").toInt * 1024
  private var MaxRetryCount_ = System.getProperty(
    "spark.broadcast.maxRetryCount", "2").toInt

  private var TrackerSocketTimeout_ = System.getProperty(
    "spark.broadcast.trackerSocketTimeout", "50000").toInt
  private var ServerSocketTimeout_ = System.getProperty(
    "spark.broadcast.serverSocketTimeout", "10000").toInt

  private var MinKnockInterval_ = System.getProperty(
    "spark.broadcast.minKnockInterval", "500").toInt
  private var MaxKnockInterval_ = System.getProperty(
    "spark.broadcast.maxKnockInterval", "999").toInt

  // Load ChainedBroadcast config params

  // Load TreeBroadcast config params
  private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt

  // Load BitTorrentBroadcast config params
  private var MaxPeersInGuideResponse_ = System.getProperty(
    "spark.broadcast.maxPeersInGuideResponse", "4").toInt

  private var MaxRxSlots_ = System.getProperty(
    "spark.broadcast.maxRxSlots", "4").toInt
  private var MaxTxSlots_ = System.getProperty(
    "spark.broadcast.maxTxSlots", "4").toInt

  private var MaxChatTime_ = System.getProperty(
    "spark.broadcast.maxChatTime", "500").toInt
  private var MaxChatBlocks_ = System.getProperty(
    "spark.broadcast.maxChatBlocks", "1024").toInt

  private var EndGameFraction_ = System.getProperty(
    "spark.broadcast.endGameFraction", "0.95").toDouble

  def isMaster = isMaster_

  // Common config params
  def MasterHostAddress = MasterHostAddress_
  def MasterTrackerPort = MasterTrackerPort_
  def BlockSize = BlockSize_
  def MaxRetryCount = MaxRetryCount_

  def TrackerSocketTimeout = TrackerSocketTimeout_
  def ServerSocketTimeout = ServerSocketTimeout_

  def MinKnockInterval = MinKnockInterval_
  def MaxKnockInterval = MaxKnockInterval_

  // ChainedBroadcast configs

  // TreeBroadcast configs
  def MaxDegree = MaxDegree_

  // BitTorrentBroadcast configs
  def MaxPeersInGuideResponse = MaxPeersInGuideResponse_

  def MaxRxSlots = MaxRxSlots_
  def MaxTxSlots = MaxTxSlots_

  def MaxChatTime = MaxChatTime_
  def MaxChatBlocks = MaxChatBlocks_

  def EndGameFraction = EndGameFraction_

  // Helper function to convert an object to Array[BroadcastBlock]
  def blockifyObject[IN](obj: IN): VariableInfo = {
    val baos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(baos)
    oos.writeObject(obj)
    oos.close()
    baos.close()
    val byteArray = baos.toByteArray
    val bais = new ByteArrayInputStream(byteArray)

    var blockNum = (byteArray.length / Broadcast.BlockSize)
    if (byteArray.length % Broadcast.BlockSize != 0)
      blockNum += 1

    val retVal = new Array[BroadcastBlock](blockNum)
    var blockID = 0

    for (i <- 0 until byteArray.length by Broadcast.BlockSize) {
      val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i)
      val tempByteArray = new Array[Byte](thisBlockSize)
      bais.read(tempByteArray, 0, thisBlockSize)

      retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
      blockID += 1
    }
    bais.close()

    val variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
    variableInfo.hasBlocks = blockNum

    return variableInfo
  }

  // Helper function to convert Array[BroadcastBlock] to object
  def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], 
                            totalBytes: Int, 
                            totalBlocks: Int): OUT = {

    val retByteArray = new Array[Byte](totalBytes)
    for (i <- 0 until totalBlocks) {
      System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
        i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
    }
    byteArrayToObject(retByteArray)
  }
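
  // Round-trip sketch (illustrative only; the payload below is made up):
  // blockifyObject and unBlockifyObject invert each other as long as the
  // byte and block counts come from the same VariableInfo.
  //
  //   val info = Broadcast.blockifyObject(Array.fill(10 * 1024 * 1024)(1.toByte))
  //   val copy = Broadcast.unBlockifyObject[Array[Byte]](
  //     info.arrayOfBlocks, info.totalBytes, info.totalBlocks)
  //   // copy.length == 10 * 1024 * 1024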

  private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
      override def resolveClass(desc: ObjectStreamClass) =
        Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
    }
    val retVal = in.readObject.asInstanceOf[OUT]
    in.close()
    return retVal
  }
}

@serializable
case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])

@serializable
case class VariableInfo(@transient val arrayOfBlocks: Array[BroadcastBlock],
                        val totalBlocks: Int,
                        val totalBytes: Int) {
  @transient var hasBlocks = 0
}

@serializable
class SpeedTracker {
  // Mapping 'source' to '(totalTime, numBlocks)'
  private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()

  def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
    sourceToSpeedMap.synchronized {
      if (!sourceToSpeedMap.contains(srcInfo)) {
        sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
      } else {
        val tTnB = sourceToSpeedMap (srcInfo)
        sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
      }
    }
  }

  def getTimePerBlock (srcInfo: SourceInfo): Double = {
    sourceToSpeedMap.synchronized {
      val tTnB = sourceToSpeedMap (srcInfo)
      // Use floating-point division so fractional milliseconds per block
      // aren't truncated by Long/Int integer division
      return tTnB._1.toDouble / tTnB._2
    }
  }

  override def toString = sourceToSpeedMap.toString
}
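
// Usage sketch for SpeedTracker (illustrative; assumes a SourceInfo instance
// `src` obtained from the tracker machinery defined elsewhere in this package):
//
//   val tracker = new SpeedTracker
//   tracker.addDataPoint(src, 120)   // a block fetched from src took 120 ms
//   tracker.addDataPoint(src, 80)
//   tracker.getTimePerBlock(src)     // => 100.0 ms per block on average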