diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/MapOutputTracker.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5e465fa22c..77b8ca1cce 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,17 +21,15 @@ import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashSet +import scala.concurrent.Await +import scala.concurrent.duration._ import akka.actor._ -import akka.dispatch._ import akka.pattern.ask -import akka.util.Duration - import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap} - +import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) @@ -52,10 +50,10 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } } -private[spark] class MapOutputTracker extends Logging { +private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { + + private val timeout = AkkaUtils.askTimeout(conf) - private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ @@ -67,14 +65,14 @@ private[spark] class MapOutputTracker extends Logging { protected val epochLock = new java.lang.Object private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup) + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. private def askTracker(message: Any): Any = { try { val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) + Await.result(future, timeout) } catch { case e: Exception => throw new SparkException("Error communicating with MapOutputTracker", e) @@ -117,11 +115,11 @@ private[spark] class MapOutputTracker extends Logging { fetching += shuffleId } } - + if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val hostPort = Utils.localHostPort() + val hostPort = Utils.localHostPort(conf) // This try-finally prevents hangs due to timeouts: try { val fetchedBytes = @@ -144,7 +142,7 @@ private[spark] class MapOutputTracker extends Logging { else{ throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) - } + } } else { statuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) @@ -184,7 +182,8 @@ private[spark] class MapOutputTracker extends Logging { } } -private[spark] class MapOutputTrackerMaster extends MapOutputTracker { +private[spark] class MapOutputTrackerMaster(conf: SparkConf) + extends MapOutputTracker(conf) { // Cache a serialized version of the output statuses for each shuffle to send them out faster private var cacheEpoch = epoch @@ -244,12 +243,12 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { case Some(bytes) => return bytes case None => - statuses = mapStatuses(shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) epochGotten = epoch } } // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "locs"; let's serialize and return that + // out a snapshot of the locations as "statuses"; let's serialize and return that val bytes = MapOutputTracker.serializeMapStatuses(statuses) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working @@ -274,6 +273,10 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { override def updateEpoch(newEpoch: Long) { // This might be called on the MapOutputTrackerMaster if we're running in local mode. } + + def has(shuffleId: Int): Boolean = { + cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) + } } private[spark] object MapOutputTracker { @@ -308,7 +311,7 @@ private[spark] object MapOutputTracker { statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { assert (statuses != null) statuses.map { - status => + status => if (status == null) { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) |