Diffstat (limited to 'core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 134
1 file changed, 134 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
new file mode 100644
index 0000000000..2f157ccdd2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+import org.apache.spark._
+import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+
+private[spark] object ResultTask {
+
+  // A simple map from the stage id to the serialized byte array of a task.
+  // It serves as a cache for task serialization, because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+  val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+  val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+
+  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
+    synchronized {
+      val old = serializedInfoCache.get(stageId).orNull
+      if (old != null) {
+        return old
+      } else {
+        val out = new ByteArrayOutputStream
+        val ser = SparkEnv.get.closureSerializer.newInstance
+        val objOut = ser.serializeStream(new GZIPOutputStream(out))
+        objOut.writeObject(rdd)
+        objOut.writeObject(func)
+        objOut.close()
+        val bytes = out.toByteArray
+        serializedInfoCache.put(stageId, bytes)
+        return bytes
+      }
+    }
+  }
+
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
+    val loader = Thread.currentThread.getContextClassLoader
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val ser = SparkEnv.get.closureSerializer.newInstance
+    val objIn = ser.deserializeStream(in)
+    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+    val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+    return (rdd, func)
+  }
+
+  def clearCache() {
+    synchronized {
+      serializedInfoCache.clear()
+    }
+  }
+}
+
+
+private[spark] class ResultTask[T, U](
+    stageId: Int,
+    var rdd: RDD[T],
+    var func: (TaskContext, Iterator[T]) => U,
+    var partition: Int,
+    @transient locs: Seq[TaskLocation],
+    val outputId: Int)
+  extends Task[U](stageId) with Externalizable {
+
+  def this() = this(0, null, null, 0, null, 0)
+
+  var split = if (rdd == null) {
+    null
+  } else {
+    rdd.partitions(partition)
+  }
+
+  @transient private val preferredLocs: Seq[TaskLocation] = {
+    if (locs == null) Nil else locs.toSet.toSeq
+  }
+
+  override def run(attemptId: Long): U = {
+    val context = new TaskContext(stageId, partition, attemptId)
+    metrics = Some(context.taskMetrics)
+    try {
+      func(context, rdd.iterator(split, context))
+    } finally {
+      context.executeOnCompleteCallbacks()
+    }
+  }
+
+  override def preferredLocations: Seq[TaskLocation] = preferredLocs
+
+  override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+
+  override def writeExternal(out: ObjectOutput) {
+    RDDCheckpointData.synchronized {
+      split = rdd.partitions(partition)
+      out.writeInt(stageId)
+      val bytes = ResultTask.serializeInfo(
+        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partition)
+      out.writeInt(outputId)
+      out.writeLong(epoch)
+      out.writeObject(split)
+    }
+  }
+
+  override def readExternal(in: ObjectInput) {
+    val stageId = in.readInt()
+    val numBytes = in.readInt()
+    val bytes = new Array[Byte](numBytes)
+    in.readFully(bytes)
+    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+    rdd = rdd_.asInstanceOf[RDD[T]]
+    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+    partition = in.readInt()
+    val outputId = in.readInt()
+    epoch = in.readLong()
+    split = in.readObject().asInstanceOf[Partition]
+  }
+}
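
For context, a minimal self-contained sketch of the caching pattern that serializeInfo/deserializeInfo implement above: serialize a stage's task info once, GZIP-compress the bytes, and reuse the cached array for every subsequent task of that stage. This is not Spark's API; plain JDK serialization stands in for SparkEnv.get.closureSerializer and a scala.collection.mutable.Map for TimeStampedHashMap, both substitutions made only to keep the example runnable outside Spark.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}
    import scala.collection.mutable

    object InfoCacheSketch {
      // Cache keyed by stage id, standing in for serializedInfoCache above.
      private val cache = mutable.Map.empty[Int, Array[Byte]]

      // Serialize and compress once per stage id; later calls hit the cache.
      def serializeInfo(stageId: Int, info: Serializable): Array[Byte] = synchronized {
        cache.getOrElseUpdate(stageId, {
          val out = new ByteArrayOutputStream
          val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
          objOut.writeObject(info)
          objOut.close() // flushes and finishes the GZIP stream
          out.toByteArray
        })
      }

      // Decompress and deserialize the cached bytes back into an object.
      def deserializeInfo(bytes: Array[Byte]): AnyRef = {
        val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
        try objIn.readObject() finally objIn.close()
      }

      def main(args: Array[String]): Unit = {
        val bytes = serializeInfo(0, "stage-0 task info")
        // Second call for the same stage returns the cached array untouched.
        assert(serializeInfo(0, "ignored: cache hit") sameElements bytes)
        println(deserializeInfo(bytes)) // prints: stage-0 task info
      }
    }

As in the ResultTask object above, lookup and insert happen under a single lock, so concurrent callers for the same stage pay the serialization cost at most once.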