aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/MapOutputTracker.scala
blob: 866d630a6d27b8b41b2965c97146d8087d9f450b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
package spark

import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet

import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._

import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
import spark.util.{MetadataCleaner, TimeStampedHashMap}


/** Messages understood by MapOutputTrackerActor. Sealed so matches on it are checked. */
private[spark] sealed trait MapOutputTrackerMessage
/**
 * Requests the serialized map output statuses of shuffle `shuffleId`.
 * `requester` identifies the asking host; the actor only uses it for logging.
 */
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
  extends MapOutputTrackerMessage
/** Asks the tracker actor to stop itself; it replies `true` before stopping. */
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage

/**
 * Actor that fronts a MapOutputTracker on the driver.
 *
 * It answers GetMapOutputStatuses with the tracker's serialized statuses for that shuffle.
 * It acknowledges StopMapOutputTracker with `true` and then stops itself.
 */
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
  def receive = {
    case GetMapOutputStatuses(id, requesterHost) =>
      logInfo("Asked to send map output locations for shuffle " + id + " to " + requesterHost)
      val reply = tracker.getSerializedLocations(id)
      sender ! reply

    case StopMapOutputTracker =>
      logInfo("MapOutputTrackerActor stopped!")
      // Reply before stopping so the asking side's future completes with `true`.
      sender ! true
      context.stop(self)
  }
}

/**
 * Tracks where the map outputs of each shuffle live and how large they are.
 *
 * On the driver, statuses are registered directly and served to other nodes through
 * MapOutputTrackerActor. On other nodes, statuses are fetched lazily from that actor and
 * cached in `mapStatuses`. The cache is dropped when a newer generation number arrives.
 */
private[spark] class MapOutputTracker extends Logging {

  // Set to the MapOutputTrackerActor living on the driver
  var trackerActor: ActorRef = _

  // shuffleId -> one MapStatus per map task; a null entry means that map output is missing.
  var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]

  // Incremented every time a fetch fails so that client nodes know to clear
  // their cache of map output locations if this happens.
  private var generation: Long = 0
  private val generationLock = new java.lang.Object

  // Cache a serialized version of the output statuses for each shuffle to send them out faster.
  // cacheGeneration is the generation the cache was built for; it is only touched under
  // generationLock, and the cache is emptied whenever `generation` moves past it.
  var cacheGeneration = generation
  val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

  // Hands this.cleanup to a MetadataCleaner, which presumably invokes it periodically to
  // evict old entries from both maps.
  val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)

  /**
   * Sends `message` to the tracker actor and waits up to 10 seconds for its reply.
   *
   * @return the actor's reply
   * @throws SparkException if the ask fails or times out
   */
  def askTracker(message: Any): Any = {
    try {
      val timeout = 10.seconds
      val future = trackerActor.ask(message)(timeout)
      Await.result(future, timeout)
    } catch {
      case e: Exception =>
        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }

  /**
   * Sends `message` to the tracker actor, expecting it to reply with `true`.
   *
   * @throws SparkException if communication fails or any other reply is received
   */
  def communicate(message: Any) {
    if (askTracker(message) != true) {
      throw new SparkException("Error reply received from MapOutputTracker")
    }
  }

  /**
   * Registers a new shuffle with `numMaps` (initially missing) map outputs.
   *
   * @throws IllegalArgumentException if `shuffleId` is already registered
   */
  def registerShuffle(shuffleId: Int, numMaps: Int) {
    if (mapStatuses.get(shuffleId).isDefined) {
      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
    }
    mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
  }

  /** Records the output status of map task `mapId` of an already-registered shuffle. */
  def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
    val array = mapStatuses(shuffleId)
    array.synchronized {
      array(mapId) = status
    }
  }

  /**
   * Replaces all map statuses of `shuffleId` with a copy of `statuses`.
   *
   * @param changeGeneration if true, also bump the generation so other nodes drop cached copies
   */
  def registerMapOutputs(
      shuffleId: Int,
      statuses: Array[MapStatus],
      changeGeneration: Boolean = false) {
    // Defensive copy so later changes to the caller's array don't leak into the tracker.
    mapStatuses.put(shuffleId, statuses.clone())
    if (changeGeneration) {
      incrementGeneration()
    }
  }

  /**
   * Marks map output `mapId` of `shuffleId` as missing, if it is still recorded at `bmAddress`.
   * The generation is bumped in either case.
   *
   * @throws SparkException if `shuffleId` is not registered
   */
  def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
    // get(...).orNull rather than apply(): apply() would throw NoSuchElementException for an
    // unknown shuffle, which would make the SparkException below unreachable.
    val array = mapStatuses.get(shuffleId).orNull
    if (array != null) {
      array.synchronized {
        if (array(mapId) != null && array(mapId).location == bmAddress) {
          array(mapId) = null
        }
      }
      incrementGeneration()
    } else {
      throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
    }
  }

  // Shuffle IDs whose statuses some thread on this node is currently fetching from the driver.
  // Its monitor guards it and is also used to wait for / announce completed fetches.
  val fetching = new HashSet[Int]

  /**
   * Called on possibly remote nodes to get the server URIs and output sizes of every map
   * output for `reduceId` in shuffle `shuffleId`.
   *
   * Statuses not yet known locally are fetched from the driver. Concurrent callers on this
   * node share a single fetch.
   *
   * @throws FetchFailedException if any map output of the shuffle is missing
   * @throws SparkException if the driver cannot be reached
   */
  def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
    val known = mapStatuses.get(shuffleId).orNull
    val statuses =
      if (known != null) {
        known
      } else {
        logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
        fetchStatuses(shuffleId)
      }
    MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
  }

  // Returns the statuses of `shuffleId`, fetching them from the driver unless another thread
  // on this node already did. At most one thread per shuffle performs the fetch at a time.
  private def fetchStatuses(shuffleId: Int): Array[MapStatus] = {
    var statuses: Array[MapStatus] = null
    fetching.synchronized {
      // Someone else is fetching it; wait for them to be done.
      while (fetching.contains(shuffleId)) {
        try {
          fetching.wait()
        } catch {
          // Deliberately ignored: keep waiting for the in-flight fetch.
          case e: InterruptedException =>
        }
      }
      // Re-check under the lock. The statuses may have arrived while we waited, or between
      // the caller's lookup and entering this block. If they are still missing (including
      // because the other thread's fetch failed), claim the fetch so later callers wait on us.
      statuses = mapStatuses.get(shuffleId).orNull
      if (statuses == null) {
        fetching += shuffleId
      }
    }

    if (statuses == null) {
      // We won the race to fetch the output locs; do so
      logInfo("Doing the fetch; tracker actor = " + trackerActor)
      val host = System.getProperty("spark.hostname", Utils.localHostName)
      // This try-finally prevents hangs due to timeouts: waiters are released even on failure.
      try {
        val fetchedBytes =
          askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
        statuses = deserializeStatuses(fetchedBytes)
        logInfo("Got the output locations")
        mapStatuses.put(shuffleId, statuses)
      } finally {
        fetching.synchronized {
          fetching -= shuffleId
          fetching.notifyAll()
        }
      }
    }
    statuses
  }

  // Evicts entries older than `cleanupTime` from the status map and the serialized cache.
  private def cleanup(cleanupTime: Long) {
    mapStatuses.clearOldValues(cleanupTime)
    cachedSerializedStatuses.clearOldValues(cleanupTime)
  }

  /** Stops the tracker actor, drops all statuses and cancels the metadata cleaner. */
  def stop() {
    communicate(StopMapOutputTracker)
    mapStatuses.clear()
    metadataCleaner.cancel()
    trackerActor = null
  }

  // Called on master to increment the generation number
  def incrementGeneration() {
    generationLock.synchronized {
      generation += 1
      logDebug("Increasing generation to " + generation)
    }
  }

  // Called on master or workers to get current generation number
  def getGeneration: Long = {
    generationLock.synchronized {
      generation
    }
  }

  // Called on workers to update the generation number, potentially clearing old outputs
  // because of a fetch failure. (Each Mesos task calls this with the latest generation
  // number on the master at the time it was created.)
  def updateGeneration(newGen: Long) {
    generationLock.synchronized {
      if (newGen > generation) {
        logInfo("Updating generation to " + newGen + " and clearing cache")
        // Clear the shared map in place instead of reassigning the (non-volatile) var, so
        // threads already holding a reference to it also see the invalidation.
        mapStatuses.clear()
        generation = newGen
      }
    }
  }

  /**
   * Returns the GZIP-compressed serialized statuses of `shuffleId`, reusing a cached copy when
   * one exists for the current generation.
   */
  def getSerializedLocations(shuffleId: Int): Array[Byte] = {
    var statuses: Array[MapStatus] = null
    var generationGotten: Long = -1
    val cached = generationLock.synchronized {
      if (generation > cacheGeneration) {
        cachedSerializedStatuses.clear()
        cacheGeneration = generation
      }
      val hit = cachedSerializedStatuses.get(shuffleId)
      if (hit.isEmpty) {
        // Cache miss: snapshot the statuses and the generation they belong to.
        statuses = mapStatuses(shuffleId)
        generationGotten = generation
      }
      hit
    }
    cached.getOrElse {
      // Serialize outside the lock; this may be slow for large shuffles.
      val bytes = serializeStatuses(statuses)
      logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
      // Add them into the table only if the generation hasn't changed while we were working
      generationLock.synchronized {
        if (generation == generationGotten) {
          cachedSerializedStatuses(shuffleId) = bytes
        }
      }
      bytes
    }
  }

  // Serialize an array of map output locations into an efficient byte format so that we can send
  // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
  // generally be pretty compressible because many map outputs will be on the same hostname.
  def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
    val out = new ByteArrayOutputStream
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    try {
      objOut.writeObject(statuses)
    } finally {
      // close() writes the GZIP trailer and releases the native Deflater, even on failure.
      objOut.close()
    }
    out.toByteArray
  }

  // Opposite of serializeStatuses.
  def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    try {
      objIn.readObject().asInstanceOf[Array[MapStatus]]
    } finally {
      // Releases the GZIPInputStream's native Inflater promptly.
      objIn.close()
    }
  }
}

private[spark] object MapOutputTracker {
  // Base of the logarithmic size encoding shared by compressSize and decompressSize.
  private val LogBase = 1.1

  /**
   * Turns one shuffle's MapStatuses into (location, approximate size) pairs for partition
   * `reduceId`, one pair per element of `statuses`, in the same order.
   *
   * @throws FetchFailedException if `statuses` is null, or if any element is null
   *                              (a location lost because its mapper failed)
   */
  def convertMapStatuses(
        shuffleId: Int,
        reduceId: Int,
        statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
    def fail(message: String): Nothing =
      throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception(message))

    if (statuses == null) {
      fail("Missing all output locations for shuffle " + shuffleId)
    }
    statuses.map { status =>
      if (status == null) {
        fail("Missing an output location for shuffle " + shuffleId)
      }
      (status.location, decompressSize(status.compressedSizes(reduceId)))
    }
  }

  /**
   * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
   * The value stored is the ceiling of log base 1.1 of the size, capped at 255. This supports
   * sizes up to about 35 GB (1.1^255) with at most 10% error.
   * Zero encodes as 0; any other size <= 1 encodes as 1.
   */
  def compressSize(size: Long): Byte = size match {
    case 0L => 0
    case tiny if tiny <= 1L => 1
    case bytes =>
      val exponent = math.ceil(math.log(bytes) / math.log(LogBase)).toInt
      math.min(255, exponent).toByte
  }

  /**
   * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
   * The byte is read as unsigned (0 to 255).
   */
  def decompressSize(compressedSize: Byte): Long =
    if (compressedSize == 0) 0L
    else math.pow(LogBase, compressedSize & 0xFF).toLong
}