// Source: core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
// Blob: f1eae9bc881aab81d23b99e168aad766e0f8040d
package spark.scheduler

import java.io._
import java.util.HashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._

import it.unimi.dsi.fastutil.io.FastBufferedOutputStream

import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream

import spark._
import spark.storage._

/**
 * Companion holding per-stage caches of the (RDD, ShuffleDependency) pair that every
 * ShuffleMapTask of a stage shares, in both serialized and deserialized form, so the
 * pair is (de)serialized at most once per stage id in this JVM.
 *
 * The caches are plain `java.util.HashMap`s; every access in this object happens inside
 * `synchronized` on the object itself.
 */
object ShuffleMapTask {
  /** Stage id -> GZIP-compressed, closure-serialized (rdd, dep) bytes. */
  val serializedInfoCache = new HashMap[Int, Array[Byte]]

  /** Stage id -> already deserialized (rdd, dep) pair. */
  val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]

  /**
   * Returns the serialized form of `rdd` and `dep` for `stageId`, computing and caching it
   * on first use. A cache hit returns the stored bytes regardless of the `rdd`/`dep`
   * arguments passed on this call.
   *
   * @param stageId cache key
   * @param rdd     RDD written first into the stream
   * @param dep     shuffle dependency written second into the stream
   * @return GZIP-compressed bytes produced by the closure serializer
   */
  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
    synchronized {
      // java.util.HashMap.get yields null on a miss.
      val cached = serializedInfoCache.get(stageId)
      if (cached != null) {
        cached
      } else {
        val out = new ByteArrayOutputStream
        val ser = SparkEnv.get.closureSerializer.newInstance
        val objOut = ser.serializeStream(new GZIPOutputStream(out))
        try {
          objOut.writeObject(rdd)
          objOut.writeObject(dep)
        } finally {
          // Close even when writing fails, and always before reading `out`, so that
          // buffered data is flushed (presumably this also finishes the GZIP stream).
          objOut.close()
        }
        val bytes = out.toByteArray
        serializedInfoCache.put(stageId, bytes)
        bytes
      }
    }
  }

  /**
   * Returns the (rdd, dep) pair for `stageId`, deserializing `bytes` and caching the result
   * on first use. A cache hit ignores `bytes`.
   *
   * @param stageId cache key
   * @param bytes   output of [[serializeInfo]] (RDD first, then the dependency)
   * @return the deserialized pair
   */
  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
    synchronized {
      // java.util.HashMap.get yields null on a miss.
      val cached = deserializedInfoCache.get(stageId)
      if (cached != null) {
        cached
      } else {
        val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
        try {
          val ser = SparkEnv.get.closureSerializer.newInstance
          val objIn = ser.deserializeStream(in)
          // Read order must mirror the write order used in serializeInfo.
          val rdd = objIn.readObject().asInstanceOf[RDD[_]]
          val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
          val tuple = (rdd, dep)
          deserializedInfoCache.put(stageId, tuple)
          tuple
        } finally {
          // Release the GZIP stream's inflater promptly instead of waiting for finalization.
          in.close()
        }
      }
    }
  }

  /** Empties both caches. */
  def clearCache(): Unit = {
    synchronized {
      serializedInfoCache.clear()
      deserializedInfoCache.clear()
    }
  }
}

/**
 * A task that reads one split of `rdd`, groups its key/value records into one bucket per
 * output partition of `dep.partitioner` (combining values per key with `dep.aggregator`),
 * and stores every bucket as a block in the block manager.
 *
 * The task is `Externalizable`: the (rdd, dep) pair is written through the per-stage cache
 * in the companion object instead of being re-serialized for every task.
 *
 * @param stageId   id of the stage this task belongs to
 * @param rdd       RDD whose split is processed
 * @param dep       shuffle dependency supplying the partitioner, aggregator and shuffle id
 * @param partition index into `rdd.splits` of the split to process
 * @param locs      preferred locations reported by `preferredLocations`; transient
 */
class ShuffleMapTask(
    stageId: Int,
    var rdd: RDD[_],
    var dep: ShuffleDependency[_,_,_],
    var partition: Int,
    @transient var locs: Seq[String])
  extends Task[BlockManagerId](stageId)
  with Externalizable
  with Logging {

  /** No-arg constructor required by `Externalizable`; fields are filled by `readExternal`. */
  def this() = this(0, null, null, 0, null)

  // `rdd` is null only when constructed through the no-arg constructor above; in that case
  // the split is restored later by readExternal.
  var split: Split = if (rdd == null) null else rdd.splits(partition)

  /**
   * Writes, in order: stage id, length-prefixed serialized (rdd, dep) bytes, partition,
   * generation and the split. `readExternal` must consume the fields in the same order.
   */
  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeInt(stageId)
    val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
    out.writeInt(bytes.length)
    out.write(bytes)
    out.writeInt(partition)
    out.writeLong(generation)
    out.writeObject(split)
  }

  /** Restores the fields written by `writeExternal`, in the same order. */
  override def readExternal(in: ObjectInput): Unit = {
    // NOTE(review): this value is only used as the cache key below; the `stageId`
    // constructor parameter is not updated from it - confirm whether that is intended.
    val serializedStageId = in.readInt()
    val numBytes = in.readInt()
    val bytes = new Array[Byte](numBytes)
    in.readFully(bytes)
    val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(serializedStageId, bytes)
    rdd = rdd_
    dep = dep_
    partition = in.readInt()
    generation = in.readLong()
    split = in.readObject().asInstanceOf[Split]
  }

  /**
   * Buckets and combines the records of `split`, then stores bucket `i` under the block id
   * `"shuffleid_" + dep.shuffleId + "_" + partition + "_" + i`.
   *
   * @param attemptId not used by this implementation
   * @return the id of this environment's block manager
   */
  override def run(attemptId: Long): BlockManagerId = {
    val partitioner = dep.partitioner
    val numOutputSplits = partitioner.numPartitions
    val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
    // One key -> combined-value map per output partition.
    val buckets = Array.fill(numOutputSplits)(new HashMap[Any, Any])
    for (elem <- rdd.iterator(split)) {
      val (k, v) = elem.asInstanceOf[(Any, Any)]
      val bucket = buckets(partitioner.getPartition(k))
      // java.util.HashMap.get yields null when the key has not been seen yet.
      val existing = bucket.get(k)
      val combined =
        if (existing == null) aggregator.createCombiner(v)
        else aggregator.mergeValue(existing, v)
      bucket.put(k, combined)
    }
    val blockManager = SparkEnv.get.blockManager
    for (i <- 0 until numOutputSplits) {
      val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
      // Get a scala iterator from java map
      val iter: Iterator[(Any, Any)] = buckets(i).iterator
      // TODO: This should probably be DISK_ONLY
      blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false)
    }
    blockManager.blockManagerId
  }

  override def preferredLocations: Seq[String] = locs

  override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}