about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 52
1 file changed, 29 insertions(+), 23 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index eb2fdecc83..9cb6159790 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -376,8 +376,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
* @param numReducers total number of reducers in the shuffle
* @param fractionThreshold fraction of total map output size that a location must have
* for it to be considered large.
- *
- * This method is not thread-safe.
*/
def getLocationsWithLargestOutputs(
shuffleId: Int,
@@ -386,28 +384,36 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
fractionThreshold: Double)
: Option[Array[BlockManagerId]] = {
- if (mapStatuses.contains(shuffleId)) {
- val statuses = mapStatuses(shuffleId)
- if (statuses.nonEmpty) {
- // HashMap to add up sizes of all blocks at the same location
- val locs = new HashMap[BlockManagerId, Long]
- var totalOutputSize = 0L
- var mapIdx = 0
- while (mapIdx < statuses.length) {
- val status = statuses(mapIdx)
- val blockSize = status.getSizeForBlock(reducerId)
- if (blockSize > 0) {
- locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
- totalOutputSize += blockSize
+ val statuses = mapStatuses.get(shuffleId).orNull
+ if (statuses != null) {
+ statuses.synchronized {
+ if (statuses.nonEmpty) {
+ // HashMap to add up sizes of all blocks at the same location
+ val locs = new HashMap[BlockManagerId, Long]
+ var totalOutputSize = 0L
+ var mapIdx = 0
+ while (mapIdx < statuses.length) {
+ val status = statuses(mapIdx)
+ // status may be null here if we are called between registerShuffle, which creates an
+ // array with null entries for each output, and registerMapOutputs, which populates it
+ // with valid status entries. This is possible if one thread schedules a job which
+ // depends on an RDD which is currently being computed by another thread.
+ if (status != null) {
+ val blockSize = status.getSizeForBlock(reducerId)
+ if (blockSize > 0) {
+ locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
+ totalOutputSize += blockSize
+ }
+ }
+ mapIdx = mapIdx + 1
+ }
+ val topLocs = locs.filter { case (loc, size) =>
+ size.toDouble / totalOutputSize >= fractionThreshold
+ }
+ // Return if we have any locations which satisfy the required threshold
+ if (topLocs.nonEmpty) {
+ return Some(topLocs.keys.toArray)
}
- mapIdx = mapIdx + 1
- }
- val topLocs = locs.filter { case (loc, size) =>
- size.toDouble / totalOutputSize >= fractionThreshold
- }
- // Return if we have any locations which satisfy the required threshold
- if (topLocs.nonEmpty) {
- return Some(topLocs.map(_._1).toArray)
}
}
}