path: root/core/src/main/scala/spark/scheduler/ResultTask.scala
blob: beb21a76fe5c8247b3373e05cb09355141d49ed8
package spark.scheduler

import spark._
import java.io._
import util.{MetadataCleaner, TimeStampedHashMap}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

private[spark] object ResultTask {

  // A simple map from stage id to the serialized byte array of a task.
  // It serves as a cache for task serialization, because serialization can be
  // expensive on the master node if it needs to launch thousands of tasks.
  val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]

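  // Periodically evicts old entries from the cache above so that it does not
  // grow without bound as stages come and go.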
  val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)

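  // Serialize the RDD and result function for a stage once, gzip the bytes
  // and cache them, so that launching many tasks for the same stage does not
  // repeat the expensive closure serialization.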
  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
    synchronized {
      val old = serializedInfoCache.get(stageId).orNull
      if (old != null) {
        return old
      } else {
        val out = new ByteArrayOutputStream
        val ser = SparkEnv.get.closureSerializer.newInstance
        val objOut = ser.serializeStream(new GZIPOutputStream(out))
        objOut.writeObject(rdd)
        objOut.writeObject(func)
        objOut.close()
        val bytes = out.toByteArray
        serializedInfoCache.put(stageId, bytes)
        return bytes
      }
    }
  }

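  // Inverse of serializeInfo: decompress the bytes and read back the RDD and
  // the result function on the worker side.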
  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
    synchronized {
      val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
      val ser = SparkEnv.get.closureSerializer.newInstance
      val objIn = ser.deserializeStream(in)
      val rdd = objIn.readObject().asInstanceOf[RDD[_]]
      val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
      return (rdd, func)
    }
  }

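  // Drop all cached task info for every stage.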
  def clearCache() {
    synchronized {
      serializedInfoCache.clear()
    }
  }
}


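// A task that applies `func` to one partition of `rdd` and returns its result
// to the scheduler on the master, as opposed to a ShuffleMapTask, whose output
// is written out for the shuffle. `outputId` is the index of this task's
// partition within the set of partitions computed by its job.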
private[spark] class ResultTask[T, U](
    stageId: Int,
    var rdd: RDD[T],
    var func: (TaskContext, Iterator[T]) => U,
    var partition: Int,
    @transient locs: Seq[String],
    var outputId: Int)
  extends Task[U](stageId) with Externalizable {

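  // No-arg constructor required by Externalizable; the fields are filled in
  // by readExternal when the task is deserialized on a worker.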
  def this() = this(0, null, null, 0, null, 0)

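  // The partition to compute. On the master this is looked up from the RDD;
  // on a worker it is restored directly from the serialized task.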
  var split = if (rdd == null) {
    null
  } else {
    rdd.partitions(partition)
  }

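  // Run `func` over this task's partition, recording task metrics and firing
  // any completion callbacks registered on the TaskContext.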
  override def run(attemptId: Long): U = {
    val context = new TaskContext(stageId, partition, attemptId)
    metrics = Some(context.taskMetrics)
    try {
      func(context, rdd.iterator(split, context))
    } finally {
      context.executeOnCompleteCallbacks()
    }
  }

  // locs is @transient, so it is null after the task has been deserialized on
  // a worker; return Nil there rather than a null reference.
  override def preferredLocations: Seq[String] = if (locs == null) Nil else locs

  override def toString = "ResultTask(" + stageId + ", " + partition + ")"

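  // Custom serialization: the (rdd, func) pair goes through the per-stage
  // cache in the ResultTask object, while the partition, output id and split
  // are written directly. Holding the RDDCheckpointData lock keeps a
  // concurrent checkpoint from changing the RDD while it is being serialized.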
  override def writeExternal(out: ObjectOutput) {
    RDDCheckpointData.synchronized {
      split = rdd.partitions(partition)
      out.writeInt(stageId)
      val bytes = ResultTask.serializeInfo(
        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
      out.writeInt(bytes.length)
      out.write(bytes)
      out.writeInt(partition)
      out.writeInt(outputId)
      out.writeObject(split)
    }
  }

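  // Read the fields back in the same order that writeExternal wrote them.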
  override def readExternal(in: ObjectInput) {
    val stageId = in.readInt()
    val numBytes = in.readInt()
    val bytes = new Array[Byte](numBytes)
    in.readFully(bytes)
    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
    rdd = rdd_.asInstanceOf[RDD[T]]
    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
    partition = in.readInt()
    outputId = in.readInt()
    split = in.readObject().asInstanceOf[Partition]
  }
}