Diffstat (limited to 'core/src/main/scala/spark/scheduler/ShuffleMapTask.scala')
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 73
1 file changed, 39 insertions(+), 34 deletions(-)
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index bd1911fce2..bed9f1864f 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -14,22 +14,25 @@ import com.ning.compress.lzf.LZFOutputStream
 import spark._
 import spark.storage._
+import util.{TimeStampedHashMap, MetadataCleaner}
 private[spark] object ShuffleMapTask {
   // A simple map between the stage id to the serialized byte array of a task.
   // Served as a cache for task serialization because serialization can be
   // expensive on the master node if it needs to launch thousands of tasks.
-  val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+  val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+  val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
   def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
     synchronized {
-      val old = serializedInfoCache.get(stageId)
+      val old = serializedInfoCache.get(stageId).orNull
       if (old != null) {
         return old
       } else {
         val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance
+        val ser = SparkEnv.get.closureSerializer.newInstance()
         val objOut = ser.serializeStream(new GZIPOutputStream(out))
         objOut.writeObject(rdd)
         objOut.writeObject(dep)
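The hunk above swaps the plain JHashMap for Spark's TimeStampedHashMap and registers a MetadataCleaner, so stale serialized task info is periodically dropped instead of accumulating on the master. As a rough sketch of the idea (a simplified, hypothetical TimestampedCache, not the actual spark.util.TimeStampedHashMap API):

import scala.collection.mutable

// Hypothetical stand-in for TimeStampedHashMap: every insert records a
// timestamp so old entries can be swept out later.
class TimestampedCache[K, V] {
  private val map = mutable.Map[K, (V, Long)]()
  def put(k: K, v: V): Unit = synchronized { map(k) = (v, System.currentTimeMillis) }
  def get(k: K): Option[V] = synchronized { map.get(k).map(_._1) }
  // Drop every entry inserted before `threshold`, mirroring clearOldValues.
  def clearOldValues(threshold: Long): Unit = synchronized {
    val stale = map.collect { case (k, (_, t)) if t < threshold => k }
    stale.foreach(map.remove)
  }
}

A periodic timer (the MetadataCleaner's role) would call clearOldValues with something like `now - ttl`.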
@@ -45,7 +48,7 @@ private[spark] object ShuffleMapTask {
     synchronized {
       val loader = Thread.currentThread.getContextClassLoader
       val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-      val ser = SparkEnv.get.closureSerializer.newInstance
+      val ser = SparkEnv.get.closureSerializer.newInstance()
       val objIn = ser.deserializeStream(in)
       val rdd = objIn.readObject().asInstanceOf[RDD[_]]
       val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@@ -78,7 +81,7 @@ private[spark] class ShuffleMapTask(
   with Externalizable
   with Logging {
-  def this() = this(0, null, null, 0, null)
+  protected def this() = this(0, null, null, 0, null)
   var split = if (rdd == null) {
     null
@@ -87,13 +90,16 @@
   }
   override def writeExternal(out: ObjectOutput) {
-    out.writeInt(stageId)
-    val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
-    out.writeInt(bytes.length)
-    out.write(bytes)
-    out.writeInt(partition)
-    out.writeLong(generation)
-    out.writeObject(split)
+    RDDCheckpointData.synchronized {
+      split = rdd.splits(partition)
+      out.writeInt(stageId)
+      val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partition)
+      out.writeLong(generation)
+      out.writeObject(split)
+    }
   }
   override def readExternal(in: ObjectInput) {
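writeExternal now runs inside RDDCheckpointData.synchronized and re-reads rdd.splits(partition), so the task is serialized against a consistent view of an RDD whose splits may be swapped out concurrently by checkpointing. The general pattern, sketched with a hypothetical CheckpointLock rather than Spark's RDDCheckpointData:

import java.io._

// Mutators and serialization share one lock, so writeExternal always sees a
// consistent snapshot of the mutable fields.
object CheckpointLock

class Counter(var value: Int) extends Externalizable {
  def this() = this(0)
  def increment(): Unit = CheckpointLock.synchronized { value += 1 }
  def writeExternal(out: ObjectOutput): Unit = CheckpointLock.synchronized {
    out.writeInt(value)
  }
  def readExternal(in: ObjectInput): Unit = { value = in.readInt() }
}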
@@ -111,34 +117,33 @@
   override def run(attemptId: Long): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
-    val partitioner = dep.partitioner
     val taskContext = new TaskContext(stageId, partition, attemptId)
+    try {
+      // Partition the map output.
+      val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+      for (elem <- rdd.iterator(split, taskContext)) {
+        val pair = elem.asInstanceOf[(Any, Any)]
+        val bucketId = dep.partitioner.getPartition(pair._1)
+        buckets(bucketId) += pair
+      }
-    // Partition the map output.
-    val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
-    for (elem <- rdd.iterator(split, taskContext)) {
-      val pair = elem.asInstanceOf[(Any, Any)]
-      val bucketId = partitioner.getPartition(pair._1)
-      buckets(bucketId) += pair
-    }
-    val bucketIterators = buckets.map(_.iterator)
+      val compressedSizes = new Array[Byte](numOutputSplits)
-    val compressedSizes = new Array[Byte](numOutputSplits)
+      val blockManager = SparkEnv.get.blockManager
+      for (i <- 0 until numOutputSplits) {
+        val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
+        // Get a Scala iterator from Java map
+        val iter: Iterator[(Any, Any)] = buckets(i).iterator
+        val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+        compressedSizes(i) = MapOutputTracker.compressSize(size)
+      }
-    val blockManager = SparkEnv.get.blockManager
-    for (i <- 0 until numOutputSplits) {
-      val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
-      // Get a Scala iterator from Java map
-      val iter: Iterator[(Any, Any)] = bucketIterators(i)
-      val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
-      compressedSizes(i) = MapOutputTracker.compressSize(size)
+      return new MapStatus(blockManager.blockManagerId, compressedSizes)
+    } finally {
+      // Execute the callbacks on task completion.
+      taskContext.executeOnCompleteCallbacks()
     }
-
-    // Execute the callbacks on task completion.
-    taskContext.executeOnCompleteCallbacks()
-
-    return new MapStatus(blockManager.blockManagerId, compressedSizes)
   }
   override def preferredLocations: Seq[String] = locs
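In run(), the body is now wrapped in try/finally so taskContext.executeOnCompleteCallbacks() fires even when computing a partition or writing a shuffle block throws; previously an exception would skip the callbacks entirely (the partitioner and bucketIterators temporaries are also inlined). A minimal sketch of that pattern, with SimpleTaskContext standing in for spark.TaskContext:

import scala.collection.mutable.ArrayBuffer

// Hypothetical, trimmed-down task context: callbacks registered during the
// task are executed once when the task finishes, successfully or not.
class SimpleTaskContext {
  private val onComplete = new ArrayBuffer[() => Unit]
  def addOnCompleteCallback(f: () => Unit): Unit = onComplete += f
  def executeOnCompleteCallbacks(): Unit = onComplete.foreach(_())
}

object TaskRunner {
  def runTask[T](body: SimpleTaskContext => T): T = {
    val ctx = new SimpleTaskContext
    try {
      body(ctx)                          // may throw
    } finally {
      ctx.executeOnCompleteCallbacks()   // always runs, success or failure
    }
  }
}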