-rw-r--r--  core/src/main/scala/org/apache/spark/util/ThreadUtils.scala                     | 59
-rw-r--r--  core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala                | 24
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala      | 15
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala | 32
4 files changed, 126 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index ca5624a3d8..22e291a2b4 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -21,6 +21,7 @@ package org.apache.spark.util
import java.util.concurrent._
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+import scala.util.control.NonFatal
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
@@ -86,4 +87,62 @@ private[spark] object ThreadUtils {
val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
Executors.newSingleThreadScheduledExecutor(threadFactory)
}
+
+ /**
+ * Run a piece of code in a new thread and return the result. Any exception thrown in the new
+ * thread is rethrown in the caller thread, with its stack trace adjusted to remove references
+ * to this method for clarity. The resulting stack trace will look like the following:
+ *
+ * SomeException: exception-message
+ * at CallerClass.body-method (sourcefile.scala)
+ * at ... run in separate thread using org.apache.spark.util.ThreadUtils ... ()
+ * at CallerClass.caller-method (sourcefile.scala)
+ * ...
+ */
+ def runInNewThread[T](
+ threadName: String,
+ isDaemon: Boolean = true)(body: => T): T = {
+ @volatile var exception: Option[Throwable] = None
+ @volatile var result: T = null.asInstanceOf[T]
+
+ val thread = new Thread(threadName) {
+ override def run(): Unit = {
+ try {
+ result = body
+ } catch {
+ case NonFatal(e) =>
+ exception = Some(e)
+ }
+ }
+ }
+ thread.setDaemon(isDaemon)
+ thread.start()
+ thread.join()
+
+ exception match {
+ case Some(realException) =>
+ // Remove the part of the stack that shows method calls into this helper method
+ // This means drop everything from the top until the stack element
+ // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`).
+ val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile(
+ ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1)
+
+ // Remove the part of the new thread stack that shows methods called from this helper method
+ val extraStackTrace = realException.getStackTrace.takeWhile(
+ ! _.getClassName.contains(this.getClass.getSimpleName))
+
+ // Combine the two stack traces, with a place holder just specifying that there
+ // was a helper method used, without any further details of the helper
+ val placeHolderStackElem = new StackTraceElement(
+ s"... run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..",
+ " ", "", -1)
+ val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace
+
+ // Update the stack trace and rethrow the exception in the caller thread
+ realException.setStackTrace(finalStackTrace)
+ throw realException
+ case None =>
+ result
+ }
+ }
}
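
For context on how the new helper is meant to be used: it behaves like a synchronous call that simply runs its body on a separate (by default daemon) thread, returning the body's result or rethrowing its exception in the caller. The following is a minimal usage sketch, not part of the commit; since ThreadUtils is private[spark], the calling code is assumed to live inside an org.apache.spark package, and the thread names are illustrative.

    import org.apache.spark.util.ThreadUtils

    // Runs the body in a new daemon thread named "example-thread", waits for it
    // to finish, and hands the result back to the calling thread.
    val threadName: String = ThreadUtils.runInNewThread("example-thread") {
      Thread.currentThread().getName
    }
    assert(threadName == "example-thread")

    // Non-fatal exceptions thrown by the body are rethrown here, in the calling
    // thread, with the adjusted stack trace described in the scaladoc above.
    try {
      ThreadUtils.runInNewThread("failing-thread") {
        throw new IllegalStateException("boom")
      }
    } catch {
      case e: IllegalStateException =>
        e.printStackTrace()  // shows the body and the caller, joined by the placeholder frame
    }

Note that only non-fatal exceptions are captured (the catch uses NonFatal), so a fatal error in the body simply terminates the spawned thread and the caller gets back the uninitialized default (null.asInstanceOf[T]).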
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
index 8c51e6b14b..620e4debf4 100644
--- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.util
import java.util.concurrent.{CountDownLatch, TimeUnit}
-import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
+import scala.concurrent.{Await, Future}
+import scala.util.Random
import org.apache.spark.SparkFunSuite
@@ -66,4 +67,25 @@ class ThreadUtilsSuite extends SparkFunSuite {
val futureThreadName = Await.result(f, 10.seconds)
assert(futureThreadName === callerThreadName)
}
+
+ test("runInNewThread") {
+ import ThreadUtils._
+ assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name")
+ assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true)
+ assert(
+ runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false
+ )
+ val uniqueExceptionMessage = "test" + Random.nextInt()
+ val exception = intercept[IllegalArgumentException] {
+ runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) }
+ }
+ assert(exception.asInstanceOf[IllegalArgumentException].getMessage === uniqueExceptionMessage)
+ assert(exception.getStackTrace.mkString("\n").contains(
+ "... run in separate thread using org.apache.spark.util.ThreadUtils ...") === true,
+ "stack trace does not contain expected place holder"
+ )
+ assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false,
+ "stack trace contains unexpected references to ThreadUtils"
+ )
+ }
}
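
One subtlety in the assertions above: the placeholder's class name in ThreadUtils deliberately ends in ".." rather than "...", because StackTraceElement.toString appends a "." before the (blank) method name, and it is that rendered string which contains the three-dot marker the test searches for. A small sketch of the rendering, assuming standard JDK toString behaviour:

    // Mirrors the placeholder frame constructed in ThreadUtils.runInNewThread.
    val placeHolder = new StackTraceElement(
      "... run in separate thread using org.apache.spark.util.ThreadUtils ..", " ", "", -1)

    // toString concatenates className + "." + methodName + "(" + fileName + ")", so the
    // trailing ".." plus the separator "." produces:
    // "... run in separate thread using org.apache.spark.util.ThreadUtils ... ()"
    println(placeHolder.toString)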
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index b496d1f341..6720ba4f72 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver}
import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener}
import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
-import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils}
+import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils}
/**
* Main entry point for Spark Streaming functionality. It provides methods used to create
@@ -588,12 +588,20 @@ class StreamingContext private[streaming] (
state match {
case INITIALIZED =>
startSite.set(DStream.getCreationSite())
- sparkContext.setCallSite(startSite.get)
StreamingContext.ACTIVATION_LOCK.synchronized {
StreamingContext.assertNoOtherContextIsActive()
try {
validate()
- scheduler.start()
+
+ // Start the streaming scheduler in a new thread, so that thread local properties
+ // like call sites and job groups can be reset without affecting those of the
+ // current thread.
+ ThreadUtils.runInNewThread("streaming-start") {
+ sparkContext.setCallSite(startSite.get)
+ sparkContext.clearJobGroup()
+ sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
+ scheduler.start()
+ }
state = StreamingContextState.ACTIVE
} catch {
case NonFatal(e) =>
@@ -618,6 +626,7 @@ class StreamingContext private[streaming] (
}
}
+
/**
* Wait for the execution to stop. Any exceptions that occur during the execution
* will be thrown in this thread.
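
The comment added in StreamingContext.start() above is the core of this change: SparkContext's local properties (call site, job group, job description, interrupt-on-cancel) are thread-local and inherited by jobs submitted from that thread, so anything the scheduler set from the user's thread would overwrite the user's own settings, and vice versa. Running scheduler.start() in the short-lived "streaming-start" thread isolates the two. A rough sketch of the intended effect, not part of the commit and assuming Spark-internal access to the property keys (they are private[spark]); the test below verifies this in full.

    // User code sets its own job group before starting the streaming context.
    sc.setJobGroup("user-group", "user work", interruptOnCancel = true)
    ssc.start()  // scheduler.start() now runs inside the "streaming-start" thread

    // The caller's thread-local properties are left untouched ...
    assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) == "user-group")

    // ... while jobs submitted by the streaming scheduler run with a cleared job
    // group and interrupt-on-cancel forced to "false".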
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index d26894e88f..3b9d0d15ea 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(ssc.scheduler.isStarted === false)
}
+ test("start should set job group and description of streaming jobs correctly") {
+ ssc = new StreamingContext(conf, batchDuration)
+ ssc.sc.setJobGroup("non-streaming", "non-streaming", true)
+ val sc = ssc.sc
+
+ @volatile var jobGroupFound: String = ""
+ @volatile var jobDescFound: String = ""
+ @volatile var jobInterruptFound: String = ""
+ @volatile var allFound: Boolean = false
+
+ addInputStream(ssc).foreachRDD { rdd =>
+ jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
+ jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION)
+ jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL)
+ allFound = true
+ }
+ ssc.start()
+
+ eventually(timeout(10 seconds), interval(10 milliseconds)) {
+ assert(allFound === true)
+ }
+
+ // Verify streaming jobs have expected thread-local properties
+ assert(jobGroupFound === null)
+ assert(jobDescFound === null)
+ assert(jobInterruptFound === "false")
+
+ // Verify current thread's thread-local properties have not changed
+ assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming")
+ assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming")
+ assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true")
+ }
test("start multiple times") {
ssc = new StreamingContext(master, appName, batchDuration)