Diffstat (limited to 'core/src/main/scala/org/apache/spark/scheduler/Task.scala')
-rw-r--r--   core/src/main/scala/org/apache/spark/scheduler/Task.scala   39
1 file changed, 30 insertions, 9 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index d2b8ca90a9..1ff9d7795f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -19,12 +19,13 @@ package org.apache.spark.scheduler
 
 import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
+import java.util.Properties
 
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl}
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}
@@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
  * @param initialAccumulators initial set of accumulators to be used in this task for tracking
  *                            internal metrics. Other accumulators will be registered later when
  *                            they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
  */
 private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
     val partitionId: Int,
-    val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
+    val initialAccumulators: Seq[Accumulator[_]],
+    @transient var localProperties: Properties) extends Serializable {
 
   /**
    * Called by [[org.apache.spark.executor.Executor]] to run this task.
@@ -71,6 +74,7 @@ private[spark] abstract class Task[T](
       taskAttemptId,
       attemptNumber,
       taskMemoryManager,
+      localProperties,
       metricsSystem,
       initialAccumulators)
     TaskContext.setTaskContext(context)
@@ -80,17 +84,24 @@ private[spark] abstract class Task[T](
     }
     try {
       runTask(context)
-    } catch { case e: Throwable =>
-      // Catch all errors; run task failure callbacks, and rethrow the exception.
-      context.markTaskFailed(e)
-      throw e
+    } catch {
+      case e: Throwable =>
+        // Catch all errors; run task failure callbacks, and rethrow the exception.
+        try {
+          context.markTaskFailed(e)
+        } catch {
+          case t: Throwable =>
+            e.addSuppressed(t)
+        }
+        throw e
     } finally {
       // Call the task completion callbacks.
       context.markTaskCompleted()
       try {
         Utils.tryLogNonFatalError {
           // Release memory used by this thread for unrolling blocks
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
+          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
           // Notify any tasks waiting for execution memory to be freed to wake up and try to
           // acquire memory again. This makes impossible the scenario where a task sleeps forever
           // because there are no other tasks left to notify it. Since this is safe to do but may
@@ -205,6 +216,11 @@ private[spark] object Task {
       dataOut.writeLong(timestamp)
     }
 
+    // Write the task properties separately so it is available before full task deserialization.
+    val propBytes = Utils.serialize(task.localProperties)
+    dataOut.writeInt(propBytes.length)
+    dataOut.write(propBytes)
+
     // Write the task itself and finish
     dataOut.flush()
     val taskBytes = serializer.serialize(task)
@@ -220,7 +236,7 @@ private[spark] object Task {
    * @return (taskFiles, taskJars, taskBytes)
    */
   def deserializeWithDependencies(serializedTask: ByteBuffer)
-    : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+    : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = {
 
     val in = new ByteBufferInputStream(serializedTask)
     val dataIn = new DataInputStream(in)
@@ -239,8 +255,13 @@ private[spark] object Task {
       taskJars(dataIn.readUTF()) = dataIn.readLong()
     }
 
+    val propLength = dataIn.readInt()
+    val propBytes = new Array[Byte](propLength)
+    dataIn.readFully(propBytes, 0, propLength)
+    val taskProps = Utils.deserialize[Properties](propBytes)
+
     // Create a sub-buffer for the rest of the data, which is the serialized Task object
     val subBuffer = serializedTask.slice()  // ByteBufferInputStream will have read just up to task
-    (taskFiles, taskJars, subBuffer)
+    (taskFiles, taskJars, taskProps, subBuffer)
   }
 }
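
Note: the point of threading localProperties through Task is to make driver-side thread-local
properties visible inside running tasks. A minimal end-to-end sketch is below; it assumes the
TaskContext.getLocalProperty accessor introduced alongside this change, and the property key is
made up for illustration.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertiesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("local-props"))
    // Set on the driver thread; carried to executors with each serialized task.
    sc.setLocalProperty("example.key", "example-value")   // hypothetical key
    val values = sc.parallelize(1 to 4, 2).map { _ =>
      // Read back on the executor side through the running task's context.
      TaskContext.get().getLocalProperty("example.key")
    }.collect()
    println(values.mkString(","))
    sc.stop()
  }
}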
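
Note: the serialization change length-prefixes the Properties object so a reader can recover it
before deserializing the full Task. The standalone sketch below reproduces that framing, with plain
JDK serialization standing in for the Spark-internal Utils.serialize/Utils.deserialize helpers.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream,
  ObjectInputStream, ObjectOutputStream}
import java.util.Properties

object PropertyFraming {
  def writeProps(dataOut: DataOutputStream, props: Properties): Unit = {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(props)
    oos.close()
    val propBytes = buf.toByteArray
    dataOut.writeInt(propBytes.length)   // length prefix lets the reader consume exactly this block
    dataOut.write(propBytes)
  }

  def readProps(dataIn: DataInputStream): Properties = {
    val propLength = dataIn.readInt()
    val propBytes = new Array[Byte](propLength)
    dataIn.readFully(propBytes, 0, propLength)
    new ObjectInputStream(new ByteArrayInputStream(propBytes)).readObject().asInstanceOf[Properties]
  }

  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.setProperty("example.key", "example-value")
    val out = new ByteArrayOutputStream()
    writeProps(new DataOutputStream(out), props)
    val back = readProps(new DataInputStream(new ByteArrayInputStream(out.toByteArray)))
    assert(back.getProperty("example.key") == "example-value")
  }
}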
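
Note: the reworked catch block wraps markTaskFailed in its own try/catch so that a failure inside
the failure callback is attached to the original task error via addSuppressed instead of masking it.
A small self-contained illustration of that pattern:

object SuppressedDemo {
  def main(args: Array[String]): Unit = {
    try {
      try {
        throw new RuntimeException("task failed")                // original error
      } catch {
        case e: Throwable =>
          try {
            throw new IllegalStateException("callback failed")   // failure-callback error
          } catch {
            case t: Throwable => e.addSuppressed(t)
          }
          throw e                                                // original error still propagates
      }
    } catch {
      case e: Throwable =>
        println(e.getMessage)                                    // prints "task failed"
        e.getSuppressed.foreach(s => println("suppressed: " + s.getMessage))
    }
  }
}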