Diffstat (limited to 'core/src/main/scala/org/apache/spark/scheduler/Task.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala  39
1 file changed, 30 insertions(+), 9 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index d2b8ca90a9..1ff9d7795f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -19,12 +19,13 @@ package org.apache.spark.scheduler
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.util.Properties
import scala.collection.mutable.HashMap
import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl}
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}
@@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
* @param initialAccumulators initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
- val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
+ val initialAccumulators: Seq[Accumulator[_]],
+ @transient var localProperties: Properties) extends Serializable {
/**
* Called by [[org.apache.spark.executor.Executor]] to run this task.
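
The new localProperties field is what lets driver-side thread-local settings reach executor code. A minimal sketch of the intended round trip, assuming the getLocalProperty accessor on TaskContext that this plumbing feeds (the accessor itself is not part of this file), with an arbitrary example key:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertyRoundTrip {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("local-props").setMaster("local[2]"))

    // Thread-local property set on the driver; with this change it is shipped
    // inside every Task serialized for this thread's jobs.
    sc.setLocalProperty("job.tag", "nightly-etl")  // "job.tag" is a made-up example key

    val seen = sc.parallelize(1 to 4, 2).map { _ =>
      // Executor side: read the property back through the TaskContext
      // (assumes TaskContext.getLocalProperty, wired up alongside this change).
      TaskContext.get().getLocalProperty("job.tag")
    }.collect()

    println(seen.mkString(", "))
    sc.stop()
  }
}
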
@@ -71,6 +74,7 @@ private[spark] abstract class Task[T](
taskAttemptId,
attemptNumber,
taskMemoryManager,
+ localProperties,
metricsSystem,
initialAccumulators)
TaskContext.setTaskContext(context)
@@ -80,17 +84,24 @@ private[spark] abstract class Task[T](
}
try {
runTask(context)
- } catch { case e: Throwable =>
- // Catch all errors; run task failure callbacks, and rethrow the exception.
- context.markTaskFailed(e)
- throw e
+ } catch {
+ case e: Throwable =>
+ // Catch all errors; run task failure callbacks, and rethrow the exception.
+ try {
+ context.markTaskFailed(e)
+ } catch {
+ case t: Throwable =>
+ e.addSuppressed(t)
+ }
+ throw e
} finally {
// Call the task completion callbacks.
context.markTaskCompleted()
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
- SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
// Notify any tasks waiting for execution memory to be freed to wake up and try to
// acquire memory again. This makes impossible the scenario where a task sleeps forever
// because there are no other tasks left to notify it. Since this is safe to do but may
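
The reworked catch block guards against a failure callback that itself throws: the secondary error is attached to the original via addSuppressed instead of replacing it, so the root cause is never masked. A standalone sketch of the same pattern, with hypothetical runTask/failureCallback stand-ins:

object SuppressedExceptionDemo {
  def main(args: Array[String]): Unit = {
    def runTask(): Unit = throw new RuntimeException("task failed")
    def failureCallback(): Unit = throw new IllegalStateException("callback also failed")

    try {
      try {
        runTask()
      } catch {
        case e: Throwable =>
          // Run the failure callback, but never let its own failure hide the
          // original task error: attach it as a suppressed exception instead.
          try {
            failureCallback()
          } catch {
            case t: Throwable => e.addSuppressed(t)
          }
          throw e
      }
    } catch {
      case e: Throwable =>
        println(s"primary: ${e.getMessage}")
        e.getSuppressed.foreach(s => println(s"suppressed: ${s.getMessage}"))
    }
  }
}
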
@@ -205,6 +216,11 @@ private[spark] object Task {
dataOut.writeLong(timestamp)
}
+ // Write the task properties separately so it is available before full task deserialization.
+ val propBytes = Utils.serialize(task.localProperties)
+ dataOut.writeInt(propBytes.length)
+ dataOut.write(propBytes)
+
// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task)
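
The properties travel as a length-prefixed frame written ahead of the serialized task, so the executor can recover them before deserializing the Task object itself. A self-contained sketch of that framing, using plain Java serialization as a stand-in for Utils.serialize/Utils.deserialize (an assumption about their behavior):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.Properties

object PropertiesFramingDemo {
  // Stand-ins for Utils.serialize / Utils.deserialize, assumed here to be plain Java serialization.
  private def serialize(o: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(o)
    oos.close()
    bos.toByteArray
  }

  private def deserialize[T](bytes: Array[Byte]): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    ois.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.setProperty("spark.job.description", "nightly ETL")

    // Write side: length prefix followed by the raw bytes, mirroring serializeWithDependencies.
    val bos = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(bos)
    val propBytes = serialize(props)
    dataOut.writeInt(propBytes.length)
    dataOut.write(propBytes)
    dataOut.flush()

    // Read side: readInt, then readFully exactly that many bytes, mirroring deserializeWithDependencies.
    val dataIn = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
    val len = dataIn.readInt()
    val buf = new Array[Byte](len)
    dataIn.readFully(buf, 0, len)
    val restored = deserialize[Properties](buf)
    println(restored.getProperty("spark.job.description"))
  }
}
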
@@ -220,7 +236,7 @@ private[spark] object Task {
* @return (taskFiles, taskJars, taskBytes)
*/
def deserializeWithDependencies(serializedTask: ByteBuffer)
- : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+ : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = {
val in = new ByteBufferInputStream(serializedTask)
val dataIn = new DataInputStream(in)
@@ -239,8 +255,13 @@ private[spark] object Task {
taskJars(dataIn.readUTF()) = dataIn.readLong()
}
+ val propLength = dataIn.readInt()
+ val propBytes = new Array[Byte](propLength)
+ dataIn.readFully(propBytes, 0, propLength)
+ val taskProps = Utils.deserialize[Properties](propBytes)
+
// Create a sub-buffer for the rest of the data, which is the serialized Task object
val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
- (taskFiles, taskJars, subBuffer)
+ (taskFiles, taskJars, taskProps, subBuffer)
}
}
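
deserializeWithDependencies relies on ByteBuffer.slice() to hand back the remaining bytes as the task payload once the header (files, jars, and now the properties frame) has been consumed, since the stream reads advance the underlying buffer's position. A small standalone sketch of that slice() behavior, with a made-up header and payload:

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

object SubBufferDemo {
  def main(args: Array[String]): Unit = {
    // A length-prefixed header followed by the "task" payload, packed into one buffer.
    val header = "header".getBytes(StandardCharsets.UTF_8)
    val payload = "serialized task body".getBytes(StandardCharsets.UTF_8)
    val buffer = ByteBuffer.allocate(4 + header.length + payload.length)
    buffer.putInt(header.length)
    buffer.put(header)
    buffer.put(payload)
    buffer.flip()

    // Consume the header: each read advances the buffer's position.
    val len = buffer.getInt()
    val headerBytes = new Array[Byte](len)
    buffer.get(headerBytes)

    // slice() yields a view starting at the current position, so the remaining
    // bytes (the task itself) can be handed off without copying.
    val subBuffer = buffer.slice()
    val rest = new Array[Byte](subBuffer.remaining())
    subBuffer.get(rest)
    println(new String(rest, StandardCharsets.UTF_8)) // prints: serialized task body
  }
}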