diff options
Diffstat (limited to 'core/src/main/scala/spark/scheduler/ShuffleMapTask.scala')
-rw-r--r-- | core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 62 |
1 files changed, 44 insertions, 18 deletions
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 36d087a4d0..95647389c3 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -13,9 +13,10 @@ import com.ning.compress.lzf.LZFInputStream import com.ning.compress.lzf.LZFOutputStream import spark._ -import executor.ShuffleWriteMetrics +import spark.executor.ShuffleWriteMetrics import spark.storage._ -import util.{TimeStampedHashMap, MetadataCleaner} +import spark.util.{TimeStampedHashMap, MetadataCleaner} + private[spark] object ShuffleMapTask { @@ -77,13 +78,20 @@ private[spark] class ShuffleMapTask( var rdd: RDD[_], var dep: ShuffleDependency[_,_], var partition: Int, - @transient var locs: Seq[String]) + @transient private var locs: Seq[String]) extends Task[MapStatus](stageId) with Externalizable with Logging { protected def this() = this(0, null, null, 0, null) + @transient private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs)) + } + var split = if (rdd == null) { null } else { @@ -121,40 +129,58 @@ private[spark] class ShuffleMapTask( val taskContext = new TaskContext(stageId, partition, attemptId) metrics = Some(taskContext.taskMetrics) + + val blockManager = SparkEnv.get.blockManager + var shuffle: ShuffleBlocks = null + var buckets: ShuffleWriterGroup = null + try { - // Partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + // Obtain all the block writers for shuffle blocks. + val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) + shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) + buckets = shuffle.acquireWriters(partition) + + // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, taskContext)) { val pair = elem.asInstanceOf[(Any, Any)] val bucketId = dep.partitioner.getPartition(pair._1) - buckets(bucketId) += pair + buckets.writers(bucketId).write(pair) } - val compressedSizes = new Array[Byte](numOutputSplits) - - var totalBytes = 0l - - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i - // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = buckets(i).iterator - val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + // Commit the writes. Get the size of each bucket block (total block size). + var totalBytes = 0L + val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.size() totalBytes += size - compressedSizes(i) = MapOutputTracker.compressSize(size) + MapOutputTracker.compressSize(size) } + + // Update shuffle metrics. val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) return new MapStatus(blockManager.blockManagerId, compressedSizes) + } catch { case e: Exception => + // If there is an exception from running the task, revert the partial writes + // and throw the exception upstream to Spark. + if (buckets != null) { + buckets.writers.foreach(_.revertPartialWrites()) + } + throw e } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && buckets != null) { + shuffle.releaseWriters(buckets) + } // Execute the callbacks on task completion. taskContext.executeOnCompleteCallbacks() } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) } |