diff options
author | Eric Liang <ekl@databricks.com> | 2016-04-11 18:33:54 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-11 18:33:54 -0700 |
commit | 6f27027d96ada29d8bb1d626f2cc7c856df3d597 (patch) | |
tree | 4a6c72e6dea1db7d4cdede7b966a68dd69b30b22 /core/src/main/scala | |
parent | 94de63053ecd709f44213d09bb43a8b2c5a8b4bb (diff) | |
download | spark-6f27027d96ada29d8bb1d626f2cc7c856df3d597.tar.gz spark-6f27027d96ada29d8bb1d626f2cc7c856df3d597.tar.bz2 spark-6f27027d96ada29d8bb1d626f2cc7c856df3d597.zip |
[SPARK-14475] Propagate user-defined context from driver to executors
## What changes were proposed in this pull request?
This adds a new API call `TaskContext.getLocalProperty` for getting properties set in the driver from executors. These local properties are automatically propagated from the driver to executors. For streaming, the context for streaming tasks will be the initial driver context when ssc.start() is called.
## How was this patch tested?
Unit tests.
cc JoshRosen
Author: Eric Liang <ekl@databricks.com>
Closes #12248 from ericl/sc-2813.
Diffstat (limited to 'core/src/main/scala')
8 files changed, 62 insertions, 13 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ec5cedf25..f0d152f05a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -602,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Set a local property that affects jobs submitted from this thread, such as the - * Spark fair scheduler pool. + * Set a local property that affects jobs submitted from this thread, such as the Spark fair + * scheduler pool. User-defined properties may also be set here. These properties are propagated + * through to worker tasks and can be accessed there via + * [[org.apache.spark.TaskContext#getLocalProperty]]. */ def setLocalProperty(key: String, value: String) { if (value == null) { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index bfcacbf229..757c1b5116 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.Serializable +import java.util.Properties import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -64,7 +65,7 @@ object TaskContext { * An empty task context that does not represent an actual task. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, null) + new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) } } @@ -162,6 +163,12 @@ abstract class TaskContext extends Serializable { */ def taskAttemptId(): Long + /** + * Get a local property set upstream in the driver, or null if it is missing. See also + * [[org.apache.spark.SparkContext.setLocalProperty]]. + */ + def getLocalProperty(key: String): String + @DeveloperApi def taskMetrics(): TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c9354b3e55..fa0b2d3d28 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.util.Properties + import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics @@ -32,6 +34,7 @@ private[spark] class TaskContextImpl( override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, @transient private val metricsSystem: MetricsSystem, initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll()) extends TaskContext @@ -118,6 +121,8 @@ private[spark] class TaskContextImpl( override def isInterrupted(): Boolean = interrupted + override def getLocalProperty(key: String): String = localProperties.getProperty(key) + override def getMetricsSources(sourceName: String): Seq[Source] = metricsSystem.getSourcesByName(sourceName) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index afa4d6093a..9f94fdef24 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer +import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ @@ -206,9 +207,16 @@ private[spark] class Executor( startGCTime = computeTotalGcTime() try { - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) + val (taskFiles, taskJars, taskProps, taskBytes) = + Task.deserializeWithDependencies(serializedTask) + + // Must be set before updateDependencies() is called, in case fetching dependencies + // requires access to properties contained within (e.g. for access control). + Executor.taskDeserializationProps.set(taskProps) + updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.localProperties = taskProps task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, @@ -506,3 +514,10 @@ private[spark] class Executor( heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } + +private[spark] object Executor { + // This is reserved for internal use by components that need to read task properties before a + // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be + // used instead. + val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties] +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5cdc91316b..4609b244e6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1036,7 +1036,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.internalAccumulators) + taskBinary, part, locs, stage.internalAccumulators, properties) } case stage: ResultStage => @@ -1046,7 +1046,7 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, stage.internalAccumulators) + taskBinary, part, locs, id, properties, stage.internalAccumulators) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index cd2736e196..db6276f75d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer +import java.util.Properties import org.apache.spark._ import org.apache.spark.broadcast.Broadcast @@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). + * @param localProperties copy of thread-local properties set by the user on the driver side. * @param _initialAccums initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. @@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U]( partition: Partition, locs: Seq[TaskLocation], val outputId: Int, + localProperties: Properties, _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll()) - extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums) + extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index e30964a01b..b7cab7013e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.Properties import scala.language.existentials @@ -42,6 +43,7 @@ import org.apache.spark.shuffle.ShuffleWriter * @param _initialAccums initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -49,13 +51,14 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - _initialAccums: Seq[Accumulator[_]]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums) + _initialAccums: Seq[Accumulator[_]], + localProperties: Properties) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c91d8fbfc4..1ff9d7795f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.util.Properties import scala.collection.mutable.HashMap @@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti * @param initialAccumulators initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. */ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - val initialAccumulators: Seq[Accumulator[_]]) extends Serializable { + val initialAccumulators: Seq[Accumulator[_]], + @transient var localProperties: Properties) extends Serializable { /** * Called by [[org.apache.spark.executor.Executor]] to run this task. @@ -71,6 +74,7 @@ private[spark] abstract class Task[T]( taskAttemptId, attemptNumber, taskMemoryManager, + localProperties, metricsSystem, initialAccumulators) TaskContext.setTaskContext(context) @@ -212,6 +216,11 @@ private[spark] object Task { dataOut.writeLong(timestamp) } + // Write the task properties separately so it is available before full task deserialization. + val propBytes = Utils.serialize(task.localProperties) + dataOut.writeInt(propBytes.length) + dataOut.write(propBytes) + // Write the task itself and finish dataOut.flush() val taskBytes = serializer.serialize(task) @@ -227,7 +236,7 @@ private[spark] object Task { * @return (taskFiles, taskJars, taskBytes) */ def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { + : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = { val in = new ByteBufferInputStream(serializedTask) val dataIn = new DataInputStream(in) @@ -246,8 +255,13 @@ private[spark] object Task { taskJars(dataIn.readUTF()) = dataIn.readLong() } + val propLength = dataIn.readInt() + val propBytes = new Array[Byte](propLength) + dataIn.readFully(propBytes, 0, propLength) + val taskProps = Utils.deserialize[Properties](propBytes) + // Create a sub-buffer for the rest of the data, which is the serialized Task object val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, subBuffer) + (taskFiles, taskJars, taskProps, subBuffer) } } |