Diffstat (limited to 'core/src/main/scala')
 core/src/main/scala/org/apache/spark/SparkContext.scala            |  6
 core/src/main/scala/org/apache/spark/TaskContext.scala             |  9
 core/src/main/scala/org/apache/spark/TaskContextImpl.scala         |  5
 core/src/main/scala/org/apache/spark/executor/Executor.scala       | 17
 core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  |  4
 core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala    |  5
 core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala|  9
 core/src/main/scala/org/apache/spark/scheduler/Task.scala          | 20
 8 files changed, 62 insertions(+), 13 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 9ec5cedf25..f0d152f05a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -602,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Set a local property that affects jobs submitted from this thread, such as the
- * Spark fair scheduler pool.
+ * Set a local property that affects jobs submitted from this thread, such as the Spark fair
+ * scheduler pool. User-defined properties may also be set here. These properties are propagated
+ * through to worker tasks and can be accessed there via
+ * [[org.apache.spark.TaskContext#getLocalProperty]].
*/
def setLocalProperty(key: String, value: String) {
if (value == null) {
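For context, a minimal driver-side sketch of the API documented above; the application name, the fair-scheduler pool name, and the "user.tag" key are illustrative assumptions, not part of this patch:

// Illustrative driver-side usage; "user.tag" and the pool name are example values.
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("local-props-demo").setMaster("local[2]"))

// Built-in use case: route jobs submitted from this thread to a fair-scheduler pool.
sc.setLocalProperty("spark.scheduler.pool", "production")

// User-defined property: copied into every task submitted from this thread and readable
// on executors through TaskContext.getLocalProperty (see the TaskContext diff below).
sc.setLocalProperty("user.tag", "nightly-batch")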
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index bfcacbf229..757c1b5116 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.Serializable
+import java.util.Properties
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
@@ -64,7 +65,7 @@ object TaskContext {
* An empty task context that does not represent an actual task.
*/
private[spark] def empty(): TaskContextImpl = {
- new TaskContextImpl(0, 0, 0, 0, null, null)
+ new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
}
}
@@ -162,6 +163,12 @@ abstract class TaskContext extends Serializable {
*/
def taskAttemptId(): Long
+ /**
+ * Get a local property set upstream in the driver, or null if it is missing. See also
+ * [[org.apache.spark.SparkContext.setLocalProperty]].
+ */
+ def getLocalProperty(key: String): String
+
@DeveloperApi
def taskMetrics(): TaskMetrics
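An executor-side counterpart to the sketch above, assuming the driver set "user.tag" with setLocalProperty before submitting the job; the key and the helper name are hypothetical:

// Hypothetical helper: reads the driver-set "user.tag" property inside each task.
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

def tagRecords(input: RDD[Int]): RDD[(String, Int)] = input.mapPartitions { iter =>
  // getLocalProperty returns null when the key was never set on the driver side.
  val tag = Option(TaskContext.get().getLocalProperty("user.tag")).getOrElse("untagged")
  iter.map(record => (tag, record))
}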
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c9354b3e55..fa0b2d3d28 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import java.util.Properties
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
@@ -32,6 +34,7 @@ private[spark] class TaskContextImpl(
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
+ localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll())
extends TaskContext
@@ -118,6 +121,8 @@ private[spark] class TaskContextImpl(
override def isInterrupted(): Boolean = interrupted
+ override def getLocalProperty(key: String): String = localProperties.getProperty(key)
+
override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index afa4d6093a..9f94fdef24 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
+import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
@@ -206,9 +207,16 @@ private[spark] class Executor(
startGCTime = computeTotalGcTime()
try {
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+ val (taskFiles, taskJars, taskProps, taskBytes) =
+ Task.deserializeWithDependencies(serializedTask)
+
+ // Must be set before updateDependencies() is called, in case fetching dependencies
+ // requires access to properties contained within (e.g. for access control).
+ Executor.taskDeserializationProps.set(taskProps)
+
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task.localProperties = taskProps
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
@@ -506,3 +514,10 @@ private[spark] class Executor(
heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
}
}
+
+private[spark] object Executor {
+ // This is reserved for internal use by components that need to read task properties before a
+ // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
+ // used instead.
+ val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
+}
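A sketch of how an internal component might consume the new thread-local during dependency fetching, per the comment above about access control; "MyFetchHook" and the "auth.token" key are invented for illustration and are not real Spark components:

// Hypothetical illustration only: not part of Spark.
import java.util.Properties
import org.apache.spark.executor.Executor

object MyFetchHook {
  def beforeFetch(): Unit = {
    // Runs on the task thread after Executor.taskDeserializationProps.set(taskProps) but
    // before the Task object itself (and hence its TaskContext) has been deserialized.
    val props: Properties = Executor.taskDeserializationProps.get()
    val token = if (props != null) props.getProperty("auth.token") else null
    if (token == null) {
      throw new IllegalStateException(
        "auth.token was not set on the driver via SparkContext.setLocalProperty")
    }
  }
}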
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5cdc91316b..4609b244e6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1036,7 +1036,7 @@ class DAGScheduler(
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, stage.internalAccumulators)
+ taskBinary, part, locs, stage.internalAccumulators, properties)
}
case stage: ResultStage =>
@@ -1046,7 +1046,7 @@ class DAGScheduler(
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, id, stage.internalAccumulators)
+ taskBinary, part, locs, id, properties, stage.internalAccumulators)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index cd2736e196..db6276f75d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
@@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
@@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U](
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
+ localProperties: Properties,
_initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll())
- extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
+ extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index e30964a01b..b7cab7013e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import java.util.Properties
import scala.language.existentials
@@ -42,6 +43,7 @@ import org.apache.spark.shuffle.ShuffleWriter
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] class ShuffleMapTask(
stageId: Int,
@@ -49,13 +51,14 @@ private[spark] class ShuffleMapTask(
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
- _initialAccums: Seq[Accumulator[_]])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
+ _initialAccums: Seq[Accumulator[_]],
+ localProperties: Properties)
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index c91d8fbfc4..1ff9d7795f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.util.Properties
import scala.collection.mutable.HashMap
@@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
* @param initialAccumulators initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
- val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
+ val initialAccumulators: Seq[Accumulator[_]],
+ @transient var localProperties: Properties) extends Serializable {
/**
* Called by [[org.apache.spark.executor.Executor]] to run this task.
@@ -71,6 +74,7 @@ private[spark] abstract class Task[T](
taskAttemptId,
attemptNumber,
taskMemoryManager,
+ localProperties,
metricsSystem,
initialAccumulators)
TaskContext.setTaskContext(context)
@@ -212,6 +216,11 @@ private[spark] object Task {
dataOut.writeLong(timestamp)
}
+ // Write the task properties separately so it is available before full task deserialization.
+ val propBytes = Utils.serialize(task.localProperties)
+ dataOut.writeInt(propBytes.length)
+ dataOut.write(propBytes)
+
// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task)
@@ -227,7 +236,7 @@ private[spark] object Task {
* @return (taskFiles, taskJars, taskBytes)
*/
def deserializeWithDependencies(serializedTask: ByteBuffer)
- : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+ : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = {
val in = new ByteBufferInputStream(serializedTask)
val dataIn = new DataInputStream(in)
@@ -246,8 +255,13 @@ private[spark] object Task {
taskJars(dataIn.readUTF()) = dataIn.readLong()
}
+ val propLength = dataIn.readInt()
+ val propBytes = new Array[Byte](propLength)
+ dataIn.readFully(propBytes, 0, propLength)
+ val taskProps = Utils.deserialize[Properties](propBytes)
+
// Create a sub-buffer for the rest of the data, which is the serialized Task object
val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
- (taskFiles, taskJars, subBuffer)
+ (taskFiles, taskJars, taskProps, subBuffer)
}
}
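For reference, a standalone sketch of the length-prefixed framing that serializeWithDependencies and deserializeWithDependencies now use for the task properties; it substitutes plain Java serialization for Spark's Utils.serialize/Utils.deserialize helpers, so it only mirrors the writeInt/readFully framing, not Spark's serializer:

// Script-style sketch: an Int length prefix followed by the serialized Properties bytes.
import java.io._
import java.util.Properties

def writeProps(out: DataOutputStream, props: Properties): Unit = {
  val buf = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(buf)
  oos.writeObject(props)
  oos.flush()
  val bytes = buf.toByteArray
  out.writeInt(bytes.length)   // length prefix, as in serializeWithDependencies
  out.write(bytes)
}

def readProps(in: DataInputStream): Properties = {
  val length = in.readInt()
  val bytes = new Array[Byte](length)
  in.readFully(bytes, 0, length)
  new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject().asInstanceOf[Properties]
}

// Round trip:
val props = new Properties()
props.setProperty("user.tag", "nightly-batch")
val sink = new ByteArrayOutputStream()
val dataOut = new DataOutputStream(sink)
writeProps(dataOut, props)
dataOut.flush()
val dataIn = new DataInputStream(new ByteArrayInputStream(sink.toByteArray))
assert(readProps(dataIn).getProperty("user.tag") == "nightly-batch")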