aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala20
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/ShuffleSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala5
-rw-r--r--project/MimaExcludes.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala8
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala5
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala9
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala10
24 files changed, 138 insertions, 36 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 9ec5cedf25..f0d152f05a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -602,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Set a local property that affects jobs submitted from this thread, such as the
- * Spark fair scheduler pool.
+ * Set a local property that affects jobs submitted from this thread, such as the Spark fair
+ * scheduler pool. User-defined properties may also be set here. These properties are propagated
+ * through to worker tasks and can be accessed there via
+ * [[org.apache.spark.TaskContext#getLocalProperty]].
*/
def setLocalProperty(key: String, value: String) {
if (value == null) {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index bfcacbf229..757c1b5116 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.Serializable
+import java.util.Properties
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
@@ -64,7 +65,7 @@ object TaskContext {
* An empty task context that does not represent an actual task.
*/
private[spark] def empty(): TaskContextImpl = {
- new TaskContextImpl(0, 0, 0, 0, null, null)
+ new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
}
}
@@ -162,6 +163,12 @@ abstract class TaskContext extends Serializable {
*/
def taskAttemptId(): Long
+ /**
+ * Get a local property set upstream in the driver, or null if it is missing. See also
+ * [[org.apache.spark.SparkContext.setLocalProperty]].
+ */
+ def getLocalProperty(key: String): String
+
@DeveloperApi
def taskMetrics(): TaskMetrics
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c9354b3e55..fa0b2d3d28 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import java.util.Properties
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
@@ -32,6 +34,7 @@ private[spark] class TaskContextImpl(
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
+ localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll())
extends TaskContext
@@ -118,6 +121,8 @@ private[spark] class TaskContextImpl(
override def isInterrupted(): Boolean = interrupted
+ override def getLocalProperty(key: String): String = localProperties.getProperty(key)
+
override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index afa4d6093a..9f94fdef24 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
+import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
@@ -206,9 +207,16 @@ private[spark] class Executor(
startGCTime = computeTotalGcTime()
try {
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+ val (taskFiles, taskJars, taskProps, taskBytes) =
+ Task.deserializeWithDependencies(serializedTask)
+
+ // Must be set before updateDependencies() is called, in case fetching dependencies
+ // requires access to properties contained within (e.g. for access control).
+ Executor.taskDeserializationProps.set(taskProps)
+
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task.localProperties = taskProps
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
@@ -506,3 +514,10 @@ private[spark] class Executor(
heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
}
}
+
+private[spark] object Executor {
+ // This is reserved for internal use by components that need to read task properties before a
+ // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
+ // used instead.
+ val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5cdc91316b..4609b244e6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1036,7 +1036,7 @@ class DAGScheduler(
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, stage.internalAccumulators)
+ taskBinary, part, locs, stage.internalAccumulators, properties)
}
case stage: ResultStage =>
@@ -1046,7 +1046,7 @@ class DAGScheduler(
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, id, stage.internalAccumulators)
+ taskBinary, part, locs, id, properties, stage.internalAccumulators)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index cd2736e196..db6276f75d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
@@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
@@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U](
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
+ localProperties: Properties,
_initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll())
- extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
+ extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index e30964a01b..b7cab7013e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import java.util.Properties
import scala.language.existentials
@@ -42,6 +43,7 @@ import org.apache.spark.shuffle.ShuffleWriter
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] class ShuffleMapTask(
stageId: Int,
@@ -49,13 +51,14 @@ private[spark] class ShuffleMapTask(
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
- _initialAccums: Seq[Accumulator[_]])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
+ _initialAccums: Seq[Accumulator[_]],
+ localProperties: Properties)
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index c91d8fbfc4..1ff9d7795f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.util.Properties
import scala.collection.mutable.HashMap
@@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
* @param initialAccumulators initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
- val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
+ val initialAccumulators: Seq[Accumulator[_]],
+ @transient var localProperties: Properties) extends Serializable {
/**
* Called by [[org.apache.spark.executor.Executor]] to run this task.
@@ -71,6 +74,7 @@ private[spark] abstract class Task[T](
taskAttemptId,
attemptNumber,
taskMemoryManager,
+ localProperties,
metricsSystem,
initialAccumulators)
TaskContext.setTaskContext(context)
@@ -212,6 +216,11 @@ private[spark] object Task {
dataOut.writeLong(timestamp)
}
+ // Write the task properties separately so it is available before full task deserialization.
+ val propBytes = Utils.serialize(task.localProperties)
+ dataOut.writeInt(propBytes.length)
+ dataOut.write(propBytes)
+
// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task)
@@ -227,7 +236,7 @@ private[spark] object Task {
* @return (taskFiles, taskJars, taskBytes)
*/
def deserializeWithDependencies(serializedTask: ByteBuffer)
- : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+ : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = {
val in = new ByteBufferInputStream(serializedTask)
val dataIn = new DataInputStream(in)
@@ -246,8 +255,13 @@ private[spark] object Task {
taskJars(dataIn.readUTF()) = dataIn.readLong()
}
+ val propLength = dataIn.readInt()
+ val propBytes = new Array[Byte](propLength)
+ dataIn.readFully(propBytes, 0, propLength)
+ val taskProps = Utils.deserialize[Properties](propBytes)
+
// Create a sub-buffer for the rest of the data, which is the serialized Task object
val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
- (taskFiles, taskJars, subBuffer)
+ (taskFiles, taskJars, taskProps, subBuffer)
}
}
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index ec192a8543..37879d11ca 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.Semaphore
import javax.annotation.concurrent.GuardedBy
@@ -292,7 +293,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
// Now we're on the executors.
// Deserialize the task and assert that its accumulators are zero'ed out.
- val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
+ val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
val taskDeser = serInstance.deserialize[DummyTask](
taskBytes, Thread.currentThread.getContextClassLoader)
// Assert that executors see only zeros
@@ -403,6 +404,6 @@ private class SaveInfoListener extends SparkListener {
private[spark] class DummyTask(
val internalAccums: Seq[Accumulator[_]],
val externalAccums: Seq[Accumulator[_]])
- extends Task[Int](0, 0, 0, internalAccums) {
+ extends Task[Int](0, 0, 0, internalAccums, new Properties) {
override def runTask(c: TaskContext): Int = 1
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 6ffa1c8ac1..00f3f15c45 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
import org.scalatest.Matchers
@@ -335,7 +336,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// first attempt -- its successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data1 = (1 to 10).map { x => x -> x}
@@ -343,7 +344,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data2 = (11 to 20).map { x => x -> x}
@@ -372,7 +373,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
}
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
- new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 2b5e4b80e9..362cd861cc 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -17,6 +17,8 @@
package org.apache.spark.memory
+import java.util.Properties
+
import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
/**
@@ -31,6 +33,7 @@ object MemoryTestingUtils {
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
metricsSystem = env.metricsSystem)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index f7e16af9d3..e3e6df6831 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.apache.spark.TaskContext
class FakeTask(
stageId: Int,
prefLocs: Seq[TaskLocation] = Nil)
- extends Task[Int](stageId, 0, 0, Seq.empty) {
+ extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 1dca4bd89f..76a7087645 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+import java.util.Properties
import org.apache.spark.TaskContext
@@ -25,7 +26,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+ extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index c4cf2f9f70..86911d2211 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.apache.spark._
-import org.apache.spark.executor.TaskMetricsSuite
+import org.apache.spark.executor.{Executor, TaskMetricsSuite}
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
@@ -59,7 +61,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -79,7 +82,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -170,9 +174,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val initialAccums = InternalAccumulator.createAll()
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
- val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) {
+ val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) {
context = new TaskContextImpl(0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
+ new Properties,
SparkEnv.get.metricsSystem,
initialAccums)
context.taskMetrics.registerAccumulator(acc1)
@@ -189,6 +194,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
}
+ test("localProperties are propagated to executors correctly") {
+ sc = new SparkContext("local", "test")
+ sc.setLocalProperty("testPropKey", "testPropValue")
+ val res = sc.parallelize(Array(1), 1).map(i => i).map(i => {
+ val inTask = TaskContext.get().getLocalProperty("testPropKey")
+ val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey")
+ s"$inTask,$inDeser"
+ }).collect()
+ assert(res === Array("testPropValue,testPropValue"))
+ }
+
}
private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 167d3fd2e4..ade8e84d84 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler
-import java.util.Random
+import java.util.{Properties, Random}
import scala.collection.Map
import scala.collection.mutable
@@ -138,7 +138,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 7ee76aa4c6..9d1bd7ec89 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.storage
+import java.util.Properties
+
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -58,7 +60,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
- TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null))
+ TaskContext.setTaskContext(
+ new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 290de794dc..a30581eb48 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -625,6 +625,9 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this")
+ ) ++ Seq(
+ // [SPARK-14475] Propagate user-defined context from driver to executors
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 4dc7d3461c..c1555114e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.Properties
+
import scala.collection.mutable
import scala.util.{Random, Try}
import scala.util.control.NonFatal
@@ -71,6 +73,7 @@ class UnsafeFixedWidthAggregationMapSuite
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
metricsSystem = null))
try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 476d93fc2a..03d4be8ee5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.Properties
+
import scala.util.Random
import org.apache.spark._
@@ -117,6 +119,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
taskAttemptId = 98456,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
+ localProperties = new Properties,
metricsSystem = null))
val sorter = new UnsafeKVExternalSorter(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 1f3779373b..7db1f9654b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File}
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.memory.TaskMemoryManager
@@ -113,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}
val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
val taskContext = new TaskContextImpl(
- 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc))
+ 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.create(sc))
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
taskContext,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index cc187f5cb4..928739a416 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming
import java.io.{InputStream, NotSerializableException}
+import java.util.Properties
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import scala.collection.Map
@@ -25,6 +26,7 @@ import scala.collection.mutable.Queue
import scala.reflect.ClassTag
import scala.util.control.NonFatal
+import org.apache.commons.lang.SerializationUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable, Text}
@@ -198,6 +200,10 @@ class StreamingContext private[streaming] (
private val startSite = new AtomicReference[CallSite](null)
+ // Copy of thread-local properties from SparkContext. These properties will be set in all tasks
+ // submitted by this StreamingContext after start.
+ private[streaming] val savedProperties = new AtomicReference[Properties](new Properties)
+
private[streaming] def getStartSite(): CallSite = startSite.get()
private var shutdownHookRef: AnyRef = _
@@ -573,6 +579,8 @@ class StreamingContext private[streaming] (
sparkContext.setCallSite(startSite.get)
sparkContext.clearJobGroup()
sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
+ savedProperties.set(SerializationUtils.clone(
+ sparkContext.localProperties.get()).asInstanceOf[Properties])
scheduler.start()
}
state = StreamingContextState.ACTIVE
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 86f069b0bd..307ff1f7ec 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -241,11 +241,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
- // Set the SparkEnv in this thread, so that job generation code can access the environment
- // Example: BlockRDDs are created in this thread, and it needs to access BlockManager
- // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
- SparkEnv.set(ssc.env)
-
// Checkpoint all RDDs marked for checkpointing to ensure their lineages are
// truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 303c325274..ac18f73ea8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -17,11 +17,14 @@
package org.apache.spark.streaming.scheduler
+import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
import scala.util.Failure
+import org.apache.commons.lang.SerializationUtils
+
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{PairRDDFunctions, RDD}
import org.apache.spark.streaming._
@@ -214,7 +217,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
import JobScheduler._
def run() {
+ val oldProps = ssc.sparkContext.getLocalProperties
try {
+ ssc.sparkContext.setLocalProperties(
+ SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties])
val formattedTime = UIUtils.formatBatchTime(
job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false)
val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}"
@@ -248,8 +254,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
// JobScheduler has been stopped.
}
} finally {
- ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null)
- ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null)
+ ssc.sparkContext.setLocalProperties(oldProps)
}
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index a80154e2fc..806e181f61 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -182,7 +182,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(ssc.scheduler.isStarted === false)
}
- test("start should set job group and description of streaming jobs correctly") {
+ test("start should set local properties of streaming jobs correctly") {
ssc = new StreamingContext(conf, batchDuration)
ssc.sc.setJobGroup("non-streaming", "non-streaming", true)
val sc = ssc.sc
@@ -190,16 +190,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
@volatile var jobGroupFound: String = ""
@volatile var jobDescFound: String = ""
@volatile var jobInterruptFound: String = ""
+ @volatile var customPropFound: String = ""
@volatile var allFound: Boolean = false
addInputStream(ssc).foreachRDD { rdd =>
jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION)
jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL)
+ customPropFound = sc.getLocalProperty("customPropKey")
allFound = true
}
+ ssc.sc.setLocalProperty("customPropKey", "value1")
ssc.start()
+ // Local props set after start should be ignored
+ ssc.sc.setLocalProperty("customPropKey", "value2")
+
eventually(timeout(10 seconds), interval(10 milliseconds)) {
assert(allFound === true)
}
@@ -208,11 +214,13 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(jobGroupFound === null)
assert(jobDescFound.contains("Streaming job from"))
assert(jobInterruptFound === "false")
+ assert(customPropFound === "value1")
// Verify current thread's thread-local properties have not changed
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming")
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming")
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true")
+ assert(sc.getLocalProperty("customPropKey") === "value2")
}
test("start multiple times") {