author     Eric Liang <ekl@databricks.com>    2016-04-11 18:33:54 -0700
committer  Reynold Xin <rxin@databricks.com>  2016-04-11 18:33:54 -0700
commit     6f27027d96ada29d8bb1d626f2cc7c856df3d597 (patch)
tree       4a6c72e6dea1db7d4cdede7b966a68dd69b30b22
parent     94de63053ecd709f44213d09bb43a8b2c5a8b4bb (diff)
[SPARK-14475] Propagate user-defined context from driver to executors
## What changes were proposed in this pull request?

This adds a new API call `TaskContext.getLocalProperty` for getting properties set in the driver from executors. These local properties are automatically propagated from the driver to executors. For streaming, the context for streaming tasks will be the initial driver context when ssc.start() is called.

## How was this patch tested?

Unit tests.

cc JoshRosen

Author: Eric Liang <ekl@databricks.com>

Closes #12248 from ericl/sc-2813.
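As a usage illustration (not part of the patch), here is a minimal sketch of the new API, based on the description and the test added to TaskContextSuite below; the property key `callerContext`, the value, and the object name are made up for the example:

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Minimal sketch: set a thread-local property on the driver and read it back
// inside tasks via the new TaskContext.getLocalProperty API.
object LocalPropertyExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("local-property-example")
    val sc = new SparkContext(conf)

    // Thread-local on the driver; captured when jobs are submitted from this thread.
    sc.setLocalProperty("callerContext", "nightly-report") // illustrative key and value

    val seen = sc.parallelize(1 to 4, 4).map { _ =>
      // Runs on the executor side; returns null if the key was never set.
      TaskContext.get().getLocalProperty("callerContext")
    }.collect()

    assert(seen.forall(_ == "nightly-report"))
    sc.stop()
  }
}
```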
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 20
-rw-r--r--  core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 7
-rw-r--r--  core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 24
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala | 5
-rw-r--r--  project/MimaExcludes.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala | 3
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala | 8
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala | 5
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala | 9
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala | 10
24 files changed, 138 insertions, 36 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 9ec5cedf25..f0d152f05a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -602,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Set a local property that affects jobs submitted from this thread, such as the
- * Spark fair scheduler pool.
+ * Set a local property that affects jobs submitted from this thread, such as the Spark fair
+ * scheduler pool. User-defined properties may also be set here. These properties are propagated
+ * through to worker tasks and can be accessed there via
+ * [[org.apache.spark.TaskContext#getLocalProperty]].
*/
def setLocalProperty(key: String, value: String) {
if (value == null) {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index bfcacbf229..757c1b5116 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.Serializable
+import java.util.Properties
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
@@ -64,7 +65,7 @@ object TaskContext {
* An empty task context that does not represent an actual task.
*/
private[spark] def empty(): TaskContextImpl = {
- new TaskContextImpl(0, 0, 0, 0, null, null)
+ new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
}
}
@@ -162,6 +163,12 @@ abstract class TaskContext extends Serializable {
*/
def taskAttemptId(): Long
+ /**
+ * Get a local property set upstream in the driver, or null if it is missing. See also
+ * [[org.apache.spark.SparkContext.setLocalProperty]].
+ */
+ def getLocalProperty(key: String): String
+
@DeveloperApi
def taskMetrics(): TaskMetrics
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c9354b3e55..fa0b2d3d28 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import java.util.Properties
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
@@ -32,6 +34,7 @@ private[spark] class TaskContextImpl(
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
+ localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll())
extends TaskContext
@@ -118,6 +121,8 @@ private[spark] class TaskContextImpl(
override def isInterrupted(): Boolean = interrupted
+ override def getLocalProperty(key: String): String = localProperties.getProperty(key)
+
override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index afa4d6093a..9f94fdef24 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
+import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
@@ -206,9 +207,16 @@ private[spark] class Executor(
startGCTime = computeTotalGcTime()
try {
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+ val (taskFiles, taskJars, taskProps, taskBytes) =
+ Task.deserializeWithDependencies(serializedTask)
+
+ // Must be set before updateDependencies() is called, in case fetching dependencies
+ // requires access to properties contained within (e.g. for access control).
+ Executor.taskDeserializationProps.set(taskProps)
+
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task.localProperties = taskProps
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
@@ -506,3 +514,10 @@ private[spark] class Executor(
heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
}
}
+
+private[spark] object Executor {
+ // This is reserved for internal use by components that need to read task properties before a
+ // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
+ // used instead.
+ val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
+}
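For spark-internal callers, the thread-local added above can be consulted while dependencies are being fetched, before the task body is deserialized. A hypothetical sketch, assuming code compiled under the org.apache.spark package (the companion object is private[spark]); the helper name and property key are illustrative:

```scala
package org.apache.spark

import org.apache.spark.executor.Executor

// Hypothetical helper: read a task property during dependency fetching, before
// the task itself has been deserialized (e.g. for an access-control decision).
private[spark] object TaskPropsAccess {
  def currentProperty(key: String): Option[String] = {
    // Set by Executor just before updateDependencies(); null outside a task thread.
    Option(Executor.taskDeserializationProps.get()).flatMap(p => Option(p.getProperty(key)))
  }
}
```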
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5cdc91316b..4609b244e6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1036,7 +1036,7 @@ class DAGScheduler(
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, stage.internalAccumulators)
+ taskBinary, part, locs, stage.internalAccumulators, properties)
}
case stage: ResultStage =>
@@ -1046,7 +1046,7 @@ class DAGScheduler(
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, id, stage.internalAccumulators)
+ taskBinary, part, locs, id, properties, stage.internalAccumulators)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index cd2736e196..db6276f75d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
@@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
@@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U](
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
+ localProperties: Properties,
_initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll())
- extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
+ extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index e30964a01b..b7cab7013e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import java.util.Properties
import scala.language.existentials
@@ -42,6 +43,7 @@ import org.apache.spark.shuffle.ShuffleWriter
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] class ShuffleMapTask(
stageId: Int,
@@ -49,13 +51,14 @@ private[spark] class ShuffleMapTask(
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
- _initialAccums: Seq[Accumulator[_]])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
+ _initialAccums: Seq[Accumulator[_]],
+ localProperties: Properties)
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index c91d8fbfc4..1ff9d7795f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.util.Properties
import scala.collection.mutable.HashMap
@@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
* @param initialAccumulators initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
- val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
+ val initialAccumulators: Seq[Accumulator[_]],
+ @transient var localProperties: Properties) extends Serializable {
/**
* Called by [[org.apache.spark.executor.Executor]] to run this task.
@@ -71,6 +74,7 @@ private[spark] abstract class Task[T](
taskAttemptId,
attemptNumber,
taskMemoryManager,
+ localProperties,
metricsSystem,
initialAccumulators)
TaskContext.setTaskContext(context)
@@ -212,6 +216,11 @@ private[spark] object Task {
dataOut.writeLong(timestamp)
}
+ // Write the task properties separately so it is available before full task deserialization.
+ val propBytes = Utils.serialize(task.localProperties)
+ dataOut.writeInt(propBytes.length)
+ dataOut.write(propBytes)
+
// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task)
@@ -227,7 +236,7 @@ private[spark] object Task {
* @return (taskFiles, taskJars, taskBytes)
*/
def deserializeWithDependencies(serializedTask: ByteBuffer)
- : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+ : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = {
val in = new ByteBufferInputStream(serializedTask)
val dataIn = new DataInputStream(in)
@@ -246,8 +255,13 @@ private[spark] object Task {
taskJars(dataIn.readUTF()) = dataIn.readLong()
}
+ val propLength = dataIn.readInt()
+ val propBytes = new Array[Byte](propLength)
+ dataIn.readFully(propBytes, 0, propLength)
+ val taskProps = Utils.deserialize[Properties](propBytes)
+
// Create a sub-buffer for the rest of the data, which is the serialized Task object
val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
- (taskFiles, taskJars, subBuffer)
+ (taskFiles, taskJars, taskProps, subBuffer)
}
}
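The serialization change above is a simple length-prefixed blob written ahead of the task body, so the properties can be read without deserializing the task itself. A standalone sketch of that framing, using plain Java serialization in place of Spark's Utils.serialize/deserialize helpers; the object and method names are illustrative:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream,
  DataOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.Properties

// Sketch of the framing: properties are written as a length-prefixed blob
// before the serialized task, and read back with the same framing.
object PropertyFraming {
  def writeProps(out: DataOutputStream, props: Properties): Unit = {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(props)
    oos.close()
    val bytes = buf.toByteArray
    out.writeInt(bytes.length) // length prefix
    out.write(bytes)
  }

  def readProps(in: DataInputStream): Properties = {
    val len = in.readInt()
    val bytes = new Array[Byte](len)
    in.readFully(bytes, 0, len)
    new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject().asInstanceOf[Properties]
  }
}
```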
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index ec192a8543..37879d11ca 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.Semaphore
import javax.annotation.concurrent.GuardedBy
@@ -292,7 +293,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
// Now we're on the executors.
// Deserialize the task and assert that its accumulators are zero'ed out.
- val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
+ val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
val taskDeser = serInstance.deserialize[DummyTask](
taskBytes, Thread.currentThread.getContextClassLoader)
// Assert that executors see only zeros
@@ -403,6 +404,6 @@ private class SaveInfoListener extends SparkListener {
private[spark] class DummyTask(
val internalAccums: Seq[Accumulator[_]],
val externalAccums: Seq[Accumulator[_]])
- extends Task[Int](0, 0, 0, internalAccums) {
+ extends Task[Int](0, 0, 0, internalAccums, new Properties) {
override def runTask(c: TaskContext): Int = 1
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 6ffa1c8ac1..00f3f15c45 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
import org.scalatest.Matchers
@@ -335,7 +336,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// first attempt -- its successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data1 = (1 to 10).map { x => x -> x}
@@ -343,7 +344,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data2 = (11 to 20).map { x => x -> x}
@@ -372,7 +373,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
}
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
- new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 2b5e4b80e9..362cd861cc 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -17,6 +17,8 @@
package org.apache.spark.memory
+import java.util.Properties
+
import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
/**
@@ -31,6 +33,7 @@ object MemoryTestingUtils {
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
metricsSystem = env.metricsSystem)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index f7e16af9d3..e3e6df6831 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.apache.spark.TaskContext
class FakeTask(
stageId: Int,
prefLocs: Seq[TaskLocation] = Nil)
- extends Task[Int](stageId, 0, 0, Seq.empty) {
+ extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 1dca4bd89f..76a7087645 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+import java.util.Properties
import org.apache.spark.TaskContext
@@ -25,7 +26,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+ extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index c4cf2f9f70..86911d2211 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.apache.spark._
-import org.apache.spark.executor.TaskMetricsSuite
+import org.apache.spark.executor.{Executor, TaskMetricsSuite}
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
@@ -59,7 +61,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -79,7 +82,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -170,9 +174,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val initialAccums = InternalAccumulator.createAll()
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
- val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) {
+ val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) {
context = new TaskContextImpl(0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
+ new Properties,
SparkEnv.get.metricsSystem,
initialAccums)
context.taskMetrics.registerAccumulator(acc1)
@@ -189,6 +194,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
}
+ test("localProperties are propagated to executors correctly") {
+ sc = new SparkContext("local", "test")
+ sc.setLocalProperty("testPropKey", "testPropValue")
+ val res = sc.parallelize(Array(1), 1).map(i => i).map(i => {
+ val inTask = TaskContext.get().getLocalProperty("testPropKey")
+ val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey")
+ s"$inTask,$inDeser"
+ }).collect()
+ assert(res === Array("testPropValue,testPropValue"))
+ }
+
}
private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 167d3fd2e4..ade8e84d84 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler
-import java.util.Random
+import java.util.{Properties, Random}
import scala.collection.Map
import scala.collection.mutable
@@ -138,7 +138,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 7ee76aa4c6..9d1bd7ec89 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.storage
+import java.util.Properties
+
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -58,7 +60,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
- TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null))
+ TaskContext.setTaskContext(
+ new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 290de794dc..a30581eb48 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -625,6 +625,9 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this")
+ ) ++ Seq(
+ // [SPARK-14475] Propagate user-defined context from driver to executors
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 4dc7d3461c..c1555114e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.Properties
+
import scala.collection.mutable
import scala.util.{Random, Try}
import scala.util.control.NonFatal
@@ -71,6 +73,7 @@ class UnsafeFixedWidthAggregationMapSuite
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
metricsSystem = null))
try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 476d93fc2a..03d4be8ee5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.Properties
+
import scala.util.Random
import org.apache.spark._
@@ -117,6 +119,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
taskAttemptId = 98456,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
+ localProperties = new Properties,
metricsSystem = null))
val sorter = new UnsafeKVExternalSorter(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 1f3779373b..7db1f9654b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File}
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.memory.TaskMemoryManager
@@ -113,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}
val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
val taskContext = new TaskContextImpl(
- 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc))
+ 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.create(sc))
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
taskContext,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index cc187f5cb4..928739a416 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming
import java.io.{InputStream, NotSerializableException}
+import java.util.Properties
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import scala.collection.Map
@@ -25,6 +26,7 @@ import scala.collection.mutable.Queue
import scala.reflect.ClassTag
import scala.util.control.NonFatal
+import org.apache.commons.lang.SerializationUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable, Text}
@@ -198,6 +200,10 @@ class StreamingContext private[streaming] (
private val startSite = new AtomicReference[CallSite](null)
+ // Copy of thread-local properties from SparkContext. These properties will be set in all tasks
+ // submitted by this StreamingContext after start.
+ private[streaming] val savedProperties = new AtomicReference[Properties](new Properties)
+
private[streaming] def getStartSite(): CallSite = startSite.get()
private var shutdownHookRef: AnyRef = _
@@ -573,6 +579,8 @@ class StreamingContext private[streaming] (
sparkContext.setCallSite(startSite.get)
sparkContext.clearJobGroup()
sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
+ savedProperties.set(SerializationUtils.clone(
+ sparkContext.localProperties.get()).asInstanceOf[Properties])
scheduler.start()
}
state = StreamingContextState.ACTIVE
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 86f069b0bd..307ff1f7ec 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -241,11 +241,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
- // Set the SparkEnv in this thread, so that job generation code can access the environment
- // Example: BlockRDDs are created in this thread, and it needs to access BlockManager
- // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
- SparkEnv.set(ssc.env)
-
// Checkpoint all RDDs marked for checkpointing to ensure their lineages are
// truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 303c325274..ac18f73ea8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -17,11 +17,14 @@
package org.apache.spark.streaming.scheduler
+import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
import scala.util.Failure
+import org.apache.commons.lang.SerializationUtils
+
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{PairRDDFunctions, RDD}
import org.apache.spark.streaming._
@@ -214,7 +217,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
import JobScheduler._
def run() {
+ val oldProps = ssc.sparkContext.getLocalProperties
try {
+ ssc.sparkContext.setLocalProperties(
+ SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties])
val formattedTime = UIUtils.formatBatchTime(
job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false)
val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}"
@@ -248,8 +254,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
// JobScheduler has been stopped.
}
} finally {
- ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null)
- ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null)
+ ssc.sparkContext.setLocalProperties(oldProps)
}
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index a80154e2fc..806e181f61 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -182,7 +182,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(ssc.scheduler.isStarted === false)
}
- test("start should set job group and description of streaming jobs correctly") {
+ test("start should set local properties of streaming jobs correctly") {
ssc = new StreamingContext(conf, batchDuration)
ssc.sc.setJobGroup("non-streaming", "non-streaming", true)
val sc = ssc.sc
@@ -190,16 +190,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
@volatile var jobGroupFound: String = ""
@volatile var jobDescFound: String = ""
@volatile var jobInterruptFound: String = ""
+ @volatile var customPropFound: String = ""
@volatile var allFound: Boolean = false
addInputStream(ssc).foreachRDD { rdd =>
jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION)
jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL)
+ customPropFound = sc.getLocalProperty("customPropKey")
allFound = true
}
+ ssc.sc.setLocalProperty("customPropKey", "value1")
ssc.start()
+ // Local props set after start should be ignored
+ ssc.sc.setLocalProperty("customPropKey", "value2")
+
eventually(timeout(10 seconds), interval(10 milliseconds)) {
assert(allFound === true)
}
@@ -208,11 +214,13 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(jobGroupFound === null)
assert(jobDescFound.contains("Streaming job from"))
assert(jobInterruptFound === "false")
+ assert(customPropFound === "value1")
// Verify current thread's thread-local properties have not changed
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming")
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming")
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true")
+ assert(sc.getLocalProperty("customPropKey") === "value2")
}
test("start multiple times") {