aboutsummaryrefslogtreecommitdiff
path: root/core/src/test
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2016-04-11 18:33:54 -0700
committerReynold Xin <rxin@databricks.com>2016-04-11 18:33:54 -0700
commit6f27027d96ada29d8bb1d626f2cc7c856df3d597 (patch)
tree4a6c72e6dea1db7d4cdede7b966a68dd69b30b22 /core/src/test
parent94de63053ecd709f44213d09bb43a8b2c5a8b4bb (diff)
downloadspark-6f27027d96ada29d8bb1d626f2cc7c856df3d597.tar.gz
spark-6f27027d96ada29d8bb1d626f2cc7c856df3d597.tar.bz2
spark-6f27027d96ada29d8bb1d626f2cc7c856df3d597.zip
[SPARK-14475] Propagate user-defined context from driver to executors
## What changes were proposed in this pull request? This adds a new API call `TaskContext.getLocalProperty` for getting properties set in the driver from executors. These local properties are automatically propagated from the driver to executors. For streaming, the context for streaming tasks will be the initial driver context when ssc.start() is called. ## How was this patch tested? Unit tests. cc JoshRosen Author: Eric Liang <ekl@databricks.com> Closes #12248 from ericl/sc-2813.
Diffstat (limited to 'core/src/test')
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/ShuffleSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala5
8 files changed, 41 insertions, 14 deletions
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index ec192a8543..37879d11ca 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.Semaphore
import javax.annotation.concurrent.GuardedBy
@@ -292,7 +293,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
// Now we're on the executors.
// Deserialize the task and assert that its accumulators are zero'ed out.
- val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
+ val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
val taskDeser = serInstance.deserialize[DummyTask](
taskBytes, Thread.currentThread.getContextClassLoader)
// Assert that executors see only zeros
@@ -403,6 +404,6 @@ private class SaveInfoListener extends SparkListener {
private[spark] class DummyTask(
val internalAccums: Seq[Accumulator[_]],
val externalAccums: Seq[Accumulator[_]])
- extends Task[Int](0, 0, 0, internalAccums) {
+ extends Task[Int](0, 0, 0, internalAccums, new Properties) {
override def runTask(c: TaskContext): Int = 1
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 6ffa1c8ac1..00f3f15c45 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
import org.scalatest.Matchers
@@ -335,7 +336,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// first attempt -- its successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data1 = (1 to 10).map { x => x -> x}
@@ -343,7 +344,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data2 = (11 to 20).map { x => x -> x}
@@ -372,7 +373,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
}
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
- new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+ new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 2b5e4b80e9..362cd861cc 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -17,6 +17,8 @@
package org.apache.spark.memory
+import java.util.Properties
+
import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
/**
@@ -31,6 +33,7 @@ object MemoryTestingUtils {
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
metricsSystem = env.metricsSystem)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index f7e16af9d3..e3e6df6831 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.apache.spark.TaskContext
class FakeTask(
stageId: Int,
prefLocs: Seq[TaskLocation] = Nil)
- extends Task[Int](stageId, 0, 0, Seq.empty) {
+ extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 1dca4bd89f..76a7087645 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+import java.util.Properties
import org.apache.spark.TaskContext
@@ -25,7 +26,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+ extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index c4cf2f9f70..86911d2211 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.apache.spark._
-import org.apache.spark.executor.TaskMetricsSuite
+import org.apache.spark.executor.{Executor, TaskMetricsSuite}
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
@@ -59,7 +61,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -79,7 +82,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -170,9 +174,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val initialAccums = InternalAccumulator.createAll()
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
- val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) {
+ val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) {
context = new TaskContextImpl(0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
+ new Properties,
SparkEnv.get.metricsSystem,
initialAccums)
context.taskMetrics.registerAccumulator(acc1)
@@ -189,6 +194,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
}
+ test("localProperties are propagated to executors correctly") {
+ sc = new SparkContext("local", "test")
+ sc.setLocalProperty("testPropKey", "testPropValue")
+ val res = sc.parallelize(Array(1), 1).map(i => i).map(i => {
+ val inTask = TaskContext.get().getLocalProperty("testPropKey")
+ val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey")
+ s"$inTask,$inDeser"
+ }).collect()
+ assert(res === Array("testPropValue,testPropValue"))
+ }
+
}
private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 167d3fd2e4..ade8e84d84 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler
-import java.util.Random
+import java.util.{Properties, Random}
import scala.collection.Map
import scala.collection.mutable
@@ -138,7 +138,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 7ee76aa4c6..9d1bd7ec89 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.storage
+import java.util.Properties
+
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -58,7 +60,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
- TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null))
+ TaskContext.setTaskContext(
+ new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()