diff options
-rw-r--r-- | core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 52 |
1 files changed, 29 insertions, 23 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index eb2fdecc83..9cb6159790 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -376,8 +376,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) * @param numReducers total number of reducers in the shuffle * @param fractionThreshold fraction of total map output size that a location must have * for it to be considered large. - * - * This method is not thread-safe. */ def getLocationsWithLargestOutputs( shuffleId: Int, @@ -386,28 +384,36 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) fractionThreshold: Double) : Option[Array[BlockManagerId]] = { - if (mapStatuses.contains(shuffleId)) { - val statuses = mapStatuses(shuffleId) - if (statuses.nonEmpty) { - // HashMap to add up sizes of all blocks at the same location - val locs = new HashMap[BlockManagerId, Long] - var totalOutputSize = 0L - var mapIdx = 0 - while (mapIdx < statuses.length) { - val status = statuses(mapIdx) - val blockSize = status.getSizeForBlock(reducerId) - if (blockSize > 0) { - locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize - totalOutputSize += blockSize + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses != null) { + statuses.synchronized { + if (statuses.nonEmpty) { + // HashMap to add up sizes of all blocks at the same location + val locs = new HashMap[BlockManagerId, Long] + var totalOutputSize = 0L + var mapIdx = 0 + while (mapIdx < statuses.length) { + val status = statuses(mapIdx) + // status may be null here if we are called between registerShuffle, which creates an + // array with null entries for each output, and registerMapOutputs, which populates it + // with valid status entries. This is possible if one thread schedules a job which + // depends on an RDD which is currently being computed by another thread. + if (status != null) { + val blockSize = status.getSizeForBlock(reducerId) + if (blockSize > 0) { + locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize + totalOutputSize += blockSize + } + } + mapIdx = mapIdx + 1 + } + val topLocs = locs.filter { case (loc, size) => + size.toDouble / totalOutputSize >= fractionThreshold + } + // Return if we have any locations which satisfy the required threshold + if (topLocs.nonEmpty) { + return Some(topLocs.keys.toArray) } - mapIdx = mapIdx + 1 - } - val topLocs = locs.filter { case (loc, size) => - size.toDouble / totalOutputSize >= fractionThreshold - } - // Return if we have any locations which satisfy the required threshold - if (topLocs.nonEmpty) { - return Some(topLocs.map(_._1).toArray) } } } |