aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/broadcast/MultiTracker.scala
blob: 3fd77af73f17624543f3cca61b3f8900cae040fb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
package spark.broadcast

import java.io._
import java.net._
import java.util.Random

import scala.collection.mutable.Map

import spark._

private object MultiTracker
extends Logging {

  // Tracker Messages: the first object a client writes on a connection to the tracker
  val REGISTER_BROADCAST_TRACKER = 0
  val UNREGISTER_BROADCAST_TRACKER = 1
  val FIND_BROADCAST_TRACKER = 2

  // Map to keep track of guides of ongoing broadcasts (broadcast id -> guide SourceInfo).
  // The tracker threads synchronize on the map itself for every access.
  var valueToGuideMap = Map[Long, SourceInfo]()

  // Random number generator, used for the randomized back-off between tracker queries
  var ranGen = new Random

  private var initialized = false
  private var _isDriver = false

  // Written by stop() and polled by the TrackMultipleValues thread, hence @volatile
  @volatile private var stopBroadcast = false

  private var trackMV: TrackMultipleValues = null

  /**
   * One-time setup, guarded by this object's monitor; later calls are no-ops.
   * On the driver this starts the TrackMultipleValues daemon thread and publishes the
   * driver's IP address in the "spark.MultiTracker.DriverHostAddress" system property.
   *
   * @param __isDriver whether this JVM is the driver
   */
  def initialize(__isDriver: Boolean) {
    synchronized {
      if (!initialized) {
        _isDriver = __isDriver

        if (isDriver) {
          trackMV = new TrackMultipleValues
          trackMV.setDaemon(true)
          trackMV.start()

          // Set DriverHostAddress to the driver's IP address for the slaves to read
          System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
        }

        initialized = true
      }
    }
  }

  /**
   * Asks the TrackMultipleValues thread to exit. The thread sees the flag the next time
   * accept() returns or times out (TrackerSocketTimeout).
   */
  def stop() {
    stopBroadcast = true
  }

  // Load common parameters. These are read once, when this object is first initialized.
  // NOTE(review): object initialization happens on first access to MultiTracker (at the
  // latest, the first initialize() call, before its body runs), so on the driver
  // DriverHostAddress_ is read before initialize() sets the property and keeps its default ("")
  // unless the property was set earlier -- confirm this is intended.
  private val DriverHostAddress_ = System.getProperty(
    "spark.MultiTracker.DriverHostAddress", "")
  private val DriverTrackerPort_ = System.getProperty(
    "spark.broadcast.driverTrackerPort", "11111").toInt
  // spark.broadcast.blockSize is given in KB
  private val BlockSize_ = System.getProperty(
    "spark.broadcast.blockSize", "4096").toInt * 1024
  private val MaxRetryCount_ = System.getProperty(
    "spark.broadcast.maxRetryCount", "2").toInt

  private val TrackerSocketTimeout_ = System.getProperty(
    "spark.broadcast.trackerSocketTimeout", "50000").toInt
  private val ServerSocketTimeout_ = System.getProperty(
    "spark.broadcast.serverSocketTimeout", "10000").toInt

  private val MinKnockInterval_ = System.getProperty(
    "spark.broadcast.minKnockInterval", "500").toInt
  private val MaxKnockInterval_ = System.getProperty(
    "spark.broadcast.maxKnockInterval", "999").toInt

  // Load TreeBroadcast config params
  private val MaxDegree_ = System.getProperty(
    "spark.broadcast.maxDegree", "2").toInt

  // Load BitTorrentBroadcast config params
  private val MaxPeersInGuideResponse_ = System.getProperty(
    "spark.broadcast.maxPeersInGuideResponse", "4").toInt

  private val MaxChatSlots_ = System.getProperty(
    "spark.broadcast.maxChatSlots", "4").toInt
  private val MaxChatTime_ = System.getProperty(
    "spark.broadcast.maxChatTime", "500").toInt
  private val MaxChatBlocks_ = System.getProperty(
    "spark.broadcast.maxChatBlocks", "1024").toInt

  private val EndGameFraction_ = System.getProperty(
    "spark.broadcast.endGameFraction", "0.95").toDouble

  def isDriver = _isDriver

  // Common config params
  def DriverHostAddress = DriverHostAddress_
  def DriverTrackerPort = DriverTrackerPort_
  def BlockSize = BlockSize_
  def MaxRetryCount = MaxRetryCount_

  def TrackerSocketTimeout = TrackerSocketTimeout_
  def ServerSocketTimeout = ServerSocketTimeout_

  def MinKnockInterval = MinKnockInterval_
  def MaxKnockInterval = MaxKnockInterval_

  // TreeBroadcast configs
  def MaxDegree = MaxDegree_

  // BitTorrentBroadcast configs
  def MaxPeersInGuideResponse = MaxPeersInGuideResponse_

  def MaxChatSlots = MaxChatSlots_
  def MaxChatTime = MaxChatTime_
  def MaxChatBlocks = MaxChatBlocks_

  def EndGameFraction = EndGameFraction_

  /**
   * Driver-side tracker server. It listens on DriverTrackerPort and serves each
   * REGISTER / UNREGISTER / FIND request on a pooled daemon thread. It runs until stop()
   * is called.
   */
  class TrackMultipleValues
  extends Thread with Logging {
    override def run() {
      val serverSocket = new ServerSocket(DriverTrackerPort)
      logInfo("TrackMultipleValues started at " + serverSocket)
      val threadPool = Utils.newDaemonCachedThreadPool()

      try {
        // accept() times out periodically so that stopBroadcast gets re-checked
        serverSocket.setSoTimeout(TrackerSocketTimeout)

        while (!stopBroadcast) {
          var clientSocket: Socket = null
          try {
            clientSocket = serverSocket.accept()
          } catch {
            case e: Exception => {
              // Typically the accept() timeout; loop around and re-check stopBroadcast
              if (stopBroadcast) {
                logInfo("Stopping TrackMultipleValues...")
              }
            }
          }

          if (clientSocket != null) {
            val socket = clientSocket
            try {
              threadPool.execute(new Runnable {
                def run() {
                  handleClient(socket)
                }
              })
            } catch {
              // execute() failed (e.g. RejectedExecutionException): no handler will ever own
              // this socket, so close it here; otherwise handleClient closes it.
              case e: Exception => socket.close()
            }
          }
        }
      } finally {
        threadPool.shutdown()
        serverSocket.close()
      }
    }

    /**
     * Serves a single tracker request on `clientSocket`. Errors are logged, not thrown, and
     * the socket is always closed afterwards.
     */
    private def handleClient(clientSocket: Socket) {
      try {
        val oos = new ObjectOutputStream(clientSocket.getOutputStream)
        oos.flush()
        val ois = new ObjectInputStream(clientSocket.getInputStream)

        try {
          // First, read message type
          ois.readObject.asInstanceOf[Int] match {
            case REGISTER_BROADCAST_TRACKER =>
              // Receive Long, then the guide's hostAddress and listenPort
              val id = ois.readObject.asInstanceOf[Long]
              val gInfo = ois.readObject.asInstanceOf[SourceInfo]

              // Add to the map; snapshot its description while still holding the lock
              val ongoing = valueToGuideMap.synchronized {
                valueToGuideMap += (id -> gInfo)
                valueToGuideMap.toString
              }

              logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + ongoing)

              // Send dummy ACK
              oos.writeObject(-1)
              oos.flush()

            case UNREGISTER_BROADCAST_TRACKER =>
              // Receive Long
              val id = ois.readObject.asInstanceOf[Long]

              // The entry is overwritten with a TxOverGoToDefault marker rather than removed,
              // presumably so later FIND requests can tell "finished" from "not started"
              val ongoing = valueToGuideMap.synchronized {
                valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
                valueToGuideMap.toString
              }

              logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + ongoing)

              // Send dummy ACK
              oos.writeObject(-1)
              oos.flush()

            case FIND_BROADCAST_TRACKER =>
              // Receive Long
              val id = ois.readObject.asInstanceOf[Long]

              // Single locked lookup; unknown ids are told to retry later
              val gInfo = valueToGuideMap.synchronized {
                valueToGuideMap.getOrElse(id, SourceInfo("", SourceInfo.TxNotStartedRetry))
              }

              logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)

              // Send reply back
              oos.writeObject(gInfo)
              oos.flush()

            case _ =>
              throw new SparkException("Undefined messageType at TrackMultipleValues")
          }
        } finally {
          ois.close()
          oos.close()
        }
      } catch {
        case e: Exception => {
          logError("TrackMultipleValues had a " + e)
        }
      } finally {
        clientSocket.close()
      }
    }
  }

  /**
   * Asks the driver's tracker for the guide of broadcast `variableLong`.
   *
   * Makes up to MaxRetryCount attempts (always at least one). It stops early once the
   * tracker answers with anything other than TxNotStartedRetry. Every attempt is followed by
   * a random sleep in [MinKnockInterval, MaxKnockInterval) ms. Connection and protocol errors
   * are logged, not thrown.
   *
   * @param variableLong id of the broadcast to look up
   * @return the tracker's last answer, or SourceInfo("", TxNotStartedRetry) if no attempt
   *         got one
   */
  def getGuideInfo(variableLong: Long): SourceInfo = {
    var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)

    var retriesLeft = MaxRetryCount
    do {
      // Fresh per attempt, so the finally below never re-closes a previous attempt's streams
      var clientSocketToTracker: Socket = null
      var oosTracker: ObjectOutputStream = null
      var oisTracker: ObjectInputStream = null

      try {
        // Connect to the tracker to find out GuideInfo
        clientSocketToTracker = new Socket(DriverHostAddress, DriverTrackerPort)
        oosTracker = new ObjectOutputStream(clientSocketToTracker.getOutputStream)
        oosTracker.flush()
        oisTracker = new ObjectInputStream(clientSocketToTracker.getInputStream)

        // Send messageType/intention
        oosTracker.writeObject(FIND_BROADCAST_TRACKER)
        oosTracker.flush()

        // Send Long and receive GuideInfo
        oosTracker.writeObject(variableLong)
        oosTracker.flush()
        gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
      } catch {
        case e: Exception => logError("getGuideInfo had a " + e)
      } finally {
        if (oisTracker != null) {
          oisTracker.close()
        }
        if (oosTracker != null) {
          oosTracker.close()
        }
        if (clientSocketToTracker != null) {
          clientSocketToTracker.close()
        }
      }

      Thread.sleep(randomKnockInterval)

      retriesLeft -= 1
    } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)

    logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
    gInfo
  }

  /**
   * Random back-off in ms, uniform over [MinKnockInterval, MaxKnockInterval). The bound given
   * to nextInt is clamped to 1, so a MaxKnockInterval <= MinKnockInterval configuration sleeps
   * MinKnockInterval instead of throwing IllegalArgumentException.
   */
  private def randomKnockInterval: Int =
    ranGen.nextInt(math.max(1, MaxKnockInterval - MinKnockInterval)) + MinKnockInterval

  /** Tells the driver's tracker that broadcast `id` is guided by `gInfo`; waits for the ACK. */
  def registerBroadcast(id: Long, gInfo: SourceInfo) {
    sendToTracker(REGISTER_BROADCAST_TRACKER, id, gInfo)
  }

  /** Tells the driver's tracker that broadcast `id` is over; waits for the ACK. */
  def unregisterBroadcast(id: Long) {
    sendToTracker(UNREGISTER_BROADCAST_TRACKER, id)
  }

  /**
   * Opens a connection to the driver's tracker. It writes `messageType` and then each element
   * of `payload` (flushing after each), and reads back and discards the Int ACK. The socket
   * and streams are closed even if an exception propagates.
   */
  private def sendToTracker(messageType: Int, payload: Any*) {
    val socket = new Socket(DriverHostAddress, DriverTrackerPort)
    try {
      val oosST = new ObjectOutputStream(socket.getOutputStream)
      oosST.flush()
      val oisST = new ObjectInputStream(socket.getInputStream)
      try {
        def send(obj: Any) {
          oosST.writeObject(obj.asInstanceOf[AnyRef])
          oosST.flush()
        }

        // Send messageType/intention, then the message body
        send(messageType)
        payload.foreach(send)

        // Receive ACK and throw it away
        oisST.readObject.asInstanceOf[Int]
      } finally {
        oisST.close()
        oosST.close()
      }
    } finally {
      socket.close()
    }
  }

  /**
   * Java-serializes `obj` and splits the bytes into BroadcastBlocks of BlockSize bytes. The
   * last block may be shorter.
   *
   * @return a VariableInfo holding the blocks, their count and the total byte length,
   *         with hasBlocks set to the block count
   */
  def blockifyObject[IN](obj: IN): VariableInfo = {
    val baos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(baos)
    try {
      oos.writeObject(obj)
    } finally {
      oos.close()
    }
    val byteArray = baos.toByteArray

    // Ceiling division without the overflow risk of (length + BlockSize - 1)
    val blockNum = byteArray.length / BlockSize +
      (if (byteArray.length % BlockSize != 0) 1 else 0)

    val retVal = new Array[BroadcastBlock](blockNum)
    for (blockID <- 0 until blockNum) {
      val start = blockID * BlockSize
      val end = start + math.min(BlockSize, byteArray.length - start)
      retVal(blockID) = new BroadcastBlock(blockID, byteArray.slice(start, end))
    }

    val variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
    variableInfo.hasBlocks = blockNum
    variableInfo
  }

  /**
   * Reassembles the first `totalBlocks` blocks, in order, into a `totalBytes`-long buffer and
   * deserializes it. Blocks are laid out by their actual lengths, not by this node's BlockSize,
   * so the result is correct even if the sender used a different block size.
   */
  def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
                            totalBytes: Int, 
                            totalBlocks: Int): OUT = {

    val retByteArray = new Array[Byte](totalBytes)
    var offset = 0
    for (i <- 0 until totalBlocks) {
      val chunk = arrayOfBlocks(i).byteArray
      System.arraycopy(chunk, 0, retByteArray, offset, chunk.length)
      offset += chunk.length
    }
    byteArrayToObject(retByteArray)
  }

  // Deserializes `bytes`, resolving classes through the current thread's context class loader
  private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
      override def resolveClass(desc: ObjectStreamClass) =
        Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
    }
    try {
      in.readObject.asInstanceOf[OUT]
    } finally {
      in.close()
    }
  }
}

/**
 * One chunk of a Java-serialized broadcast value (e.g. as produced by
 * MultiTracker.blockifyObject, which cuts the bytes into BlockSize-sized pieces with a
 * possibly shorter last piece).
 *
 * @param blockID   zero-based position of this chunk within the serialized value
 * @param byteArray the chunk's bytes
 *
 * NOTE: byteArray is an Array, so the generated equals/hashCode compare it by reference,
 * not by content.
 */
private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) 
extends Serializable

/**
 * A broadcast value split into BroadcastBlocks, plus the sizes needed to reassemble it.
 *
 * @param arrayOfBlocks the blocks; @transient, so they are left out when this object is
 *                      Java-serialized (only the sizes travel)
 * @param totalBlocks   number of blocks the value was split into
 * @param totalBytes    length in bytes of the serialized value
 */
private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
                        totalBlocks: Int, 
                        totalBytes: Int) 
extends Serializable {
 // Presumably the number of blocks held locally (blockifyObject sets it to totalBlocks) --
 // confirm with other callers. @transient, so it reads 0 after deserialization.
 @transient var hasBlocks = 0 
}