aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorPatrick Wendell <pwendell@gmail.com>2014-01-12 20:04:21 -0800
committerPatrick Wendell <pwendell@gmail.com>2014-01-12 20:04:21 -0800
commit405bfe86ef9c3021358d2ac89192857478861fe0 (patch)
treebeb56b9652ca10c82d7bc165aadc43af54cda940 /streaming
parent28a6b0cdbc75d58e36b1da3dcf257c61e44b0f7a (diff)
parentaa2c993858f87adc249eb9c20a908a125f8f4033 (diff)
downloadspark-405bfe86ef9c3021358d2ac89192857478861fe0.tar.gz
spark-405bfe86ef9c3021358d2ac89192857478861fe0.tar.bz2
spark-405bfe86ef9c3021358d2ac89192857478861fe0.zip
Merge pull request #394 from tdas/error-handling
Better error handling in Spark Streaming and more API cleanup Earlier errors in jobs generated by Spark Streaming (or in the generation of jobs) could not be caught from the main driver thread (i.e. the thread that called StreamingContext.start()) as it would be thrown in different threads. With this change, after `ssc.start`, one can call `ssc.awaitTermination()` which will be block until the ssc is closed, or there is an exception. This makes it easier to debug. This change also adds ssc.stop(<stop-spark-context>) where you can stop StreamingContext without stopping the SparkContext. Also fixes the bug that came up with PRs #393 and #381. MetadataCleaner default value has been changed from 3500 to -1 for normal SparkContext and 3600 when creating a StreamingContext. Also, updated StreamingListenerBus with changes similar to SparkListenerBus in #392. And changed a lot of protected[streaming] to private[streaming].
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala28
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStream.scala62
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala8
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala112
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala21
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala9
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala81
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala141
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala18
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala40
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala21
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala13
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala6
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala22
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala218
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala7
20 files changed, 588 insertions, 228 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 1249ef4c3d..108bc2de3e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -40,7 +40,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val graph = ssc.graph
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
- val pendingTimes = ssc.scheduler.getPendingTimes()
+ val pendingTimes = ssc.scheduler.getPendingTimes().toArray
val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf)
val sparkConf = ssc.conf
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
new file mode 100644
index 0000000000..1f5dacb543
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
@@ -0,0 +1,28 @@
+package org.apache.spark.streaming
+
+private[streaming] class ContextWaiter {
+ private var error: Throwable = null
+ private var stopped: Boolean = false
+
+ def notifyError(e: Throwable) = synchronized {
+ error = e
+ notifyAll()
+ }
+
+ def notifyStop() = synchronized {
+ notifyAll()
+ }
+
+ def waitForStopOrError(timeout: Long = -1) = synchronized {
+ // If already had error, then throw it
+ if (error != null) {
+ throw error
+ }
+
+ // If not already stopped, then wait
+ if (!stopped) {
+ if (timeout < 0) wait() else wait(timeout)
+ if (error != null) throw error
+ }
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 9432a709d0..f760093579 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -54,7 +54,7 @@ import org.apache.spark.util.MetadataCleaner
*/
abstract class DStream[T: ClassTag] (
- @transient protected[streaming] var ssc: StreamingContext
+ @transient private[streaming] var ssc: StreamingContext
) extends Serializable with Logging {
// =======================================================================
@@ -74,31 +74,31 @@ abstract class DStream[T: ClassTag] (
// Methods and fields available on all DStreams
// =======================================================================
- // RDDs generated, marked as protected[streaming] so that testsuites can access it
+ // RDDs generated, marked as private[streaming] so that testsuites can access it
@transient
- protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
+ private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
// Time zero for the DStream
- protected[streaming] var zeroTime: Time = null
+ private[streaming] var zeroTime: Time = null
// Duration for which the DStream will remember each RDD created
- protected[streaming] var rememberDuration: Duration = null
+ private[streaming] var rememberDuration: Duration = null
// Storage level of the RDDs in the stream
- protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
+ private[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
// Checkpoint details
- protected[streaming] val mustCheckpoint = false
- protected[streaming] var checkpointDuration: Duration = null
- protected[streaming] val checkpointData = new DStreamCheckpointData(this)
+ private[streaming] val mustCheckpoint = false
+ private[streaming] var checkpointDuration: Duration = null
+ private[streaming] val checkpointData = new DStreamCheckpointData(this)
// Reference to whole DStream graph
- protected[streaming] var graph: DStreamGraph = null
+ private[streaming] var graph: DStreamGraph = null
- protected[streaming] def isInitialized = (zeroTime != null)
+ private[streaming] def isInitialized = (zeroTime != null)
// Duration for which the DStream requires its parent DStream to remember each RDD created
- protected[streaming] def parentRememberDuration = rememberDuration
+ private[streaming] def parentRememberDuration = rememberDuration
/** Return the StreamingContext associated with this DStream */
def context = ssc
@@ -138,7 +138,7 @@ abstract class DStream[T: ClassTag] (
* the validity of future times is calculated. This method also recursively initializes
* its parent DStreams.
*/
- protected[streaming] def initialize(time: Time) {
+ private[streaming] def initialize(time: Time) {
if (zeroTime != null && zeroTime != time) {
throw new Exception("ZeroTime is already initialized to " + zeroTime
+ ", cannot initialize it again to " + time)
@@ -164,7 +164,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.initialize(zeroTime))
}
- protected[streaming] def validate() {
+ private[streaming] def validate() {
assert(rememberDuration != null, "Remember duration is set to null")
assert(
@@ -228,7 +228,7 @@ abstract class DStream[T: ClassTag] (
logInfo("Initialized and validated " + this)
}
- protected[streaming] def setContext(s: StreamingContext) {
+ private[streaming] def setContext(s: StreamingContext) {
if (ssc != null && ssc != s) {
throw new Exception("Context is already set in " + this + ", cannot set it again")
}
@@ -237,7 +237,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.setContext(ssc))
}
- protected[streaming] def setGraph(g: DStreamGraph) {
+ private[streaming] def setGraph(g: DStreamGraph) {
if (graph != null && graph != g) {
throw new Exception("Graph is already set in " + this + ", cannot set it again")
}
@@ -245,7 +245,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.setGraph(graph))
}
- protected[streaming] def remember(duration: Duration) {
+ private[streaming] def remember(duration: Duration) {
if (duration != null && duration > rememberDuration) {
rememberDuration = duration
logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
@@ -254,14 +254,14 @@ abstract class DStream[T: ClassTag] (
}
/** Checks whether the 'time' is valid wrt slideDuration for generating RDD */
- protected def isTimeValid(time: Time): Boolean = {
+ private[streaming] def isTimeValid(time: Time): Boolean = {
if (!isInitialized) {
throw new Exception (this + " has not been initialized")
} else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) {
logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime))
false
} else {
- logInfo("Time " + time + " is valid")
+ logDebug("Time " + time + " is valid")
true
}
}
@@ -270,7 +270,7 @@ abstract class DStream[T: ClassTag] (
* Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal
* method that should not be called directly.
*/
- protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
+ private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
// If this DStream was not initialized (i.e., zeroTime not set), then do it
// If RDD was already generated, then retrieve it from HashMap
generatedRDDs.get(time) match {
@@ -311,7 +311,7 @@ abstract class DStream[T: ClassTag] (
* that materializes the corresponding RDD. Subclasses of DStream may override this
* to generate their own jobs.
*/
- protected[streaming] def generateJob(time: Time): Option[Job] = {
+ private[streaming] def generateJob(time: Time): Option[Job] = {
getOrCompute(time) match {
case Some(rdd) => {
val jobFunc = () => {
@@ -330,7 +330,7 @@ abstract class DStream[T: ClassTag] (
* implementation clears the old generated RDDs. Subclasses of DStream may override
* this to clear their own metadata along with the generated RDDs.
*/
- protected[streaming] def clearMetadata(time: Time) {
+ private[streaming] def clearMetadata(time: Time) {
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
generatedRDDs --= oldRDDs.keys
logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +
@@ -339,9 +339,9 @@ abstract class DStream[T: ClassTag] (
}
/* Adds metadata to the Stream while it is running.
- * This methd should be overwritten by sublcasses of InputDStream.
+ * This method should be overwritten by sublcasses of InputDStream.
*/
- protected[streaming] def addMetadata(metadata: Any) {
+ private[streaming] def addMetadata(metadata: Any) {
if (metadata != null) {
logInfo("Dropping Metadata: " + metadata.toString)
}
@@ -354,18 +354,18 @@ abstract class DStream[T: ClassTag] (
* checkpointData. Subclasses of DStream (especially those of InputDStream) may override
* this method to save custom checkpoint data.
*/
- protected[streaming] def updateCheckpointData(currentTime: Time) {
- logInfo("Updating checkpoint data for time " + currentTime)
+ private[streaming] def updateCheckpointData(currentTime: Time) {
+ logDebug("Updating checkpoint data for time " + currentTime)
checkpointData.update(currentTime)
dependencies.foreach(_.updateCheckpointData(currentTime))
logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData)
}
- protected[streaming] def clearCheckpointData(time: Time) {
- logInfo("Clearing checkpoint data")
+ private[streaming] def clearCheckpointData(time: Time) {
+ logDebug("Clearing checkpoint data")
checkpointData.cleanup(time)
dependencies.foreach(_.clearCheckpointData(time))
- logInfo("Cleared checkpoint data")
+ logDebug("Cleared checkpoint data")
}
/**
@@ -374,7 +374,7 @@ abstract class DStream[T: ClassTag] (
* from the checkpoint file names stored in checkpointData. Subclasses of DStream that
* override the updateCheckpointData() method would also need to override this method.
*/
- protected[streaming] def restoreCheckpointData() {
+ private[streaming] def restoreCheckpointData() {
// Create RDDs from the checkpoint data
logInfo("Restoring checkpoint data")
checkpointData.restore()
@@ -699,7 +699,7 @@ abstract class DStream[T: ClassTag] (
/**
* Return all the RDDs defined by the Interval object (both end times included)
*/
- protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = {
+ def slice(interval: Interval): Seq[RDD[T]] = {
slice(interval.beginTime, interval.endTime)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index eee9591ffc..668e5324e6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming
-import dstream.InputDStream
+import org.apache.spark.streaming.dstream.{NetworkInputDStream, InputDStream}
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import org.apache.spark.Logging
@@ -103,6 +103,12 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def getOutputStreams() = this.synchronized { outputStreams.toArray }
+ def getNetworkInputStreams() = this.synchronized {
+ inputStreams.filter(_.isInstanceOf[NetworkInputDStream[_]])
+ .map(_.asInstanceOf[NetworkInputDStream[_]])
+ .toArray
+ }
+
def generateJobs(time: Time): Seq[Job] = {
logDebug("Generating jobs for time " + time)
val jobs = this.synchronized {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index dd34f6f4f2..ee83ae902b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -46,7 +46,7 @@ import org.apache.hadoop.conf.Configuration
* information (such as, cluster URL and job name) to internally create a SparkContext, it provides
* methods used to create DStream from various input sources.
*/
-class StreamingContext private (
+class StreamingContext private[streaming] (
sc_ : SparkContext,
cp_ : Checkpoint,
batchDur_ : Duration
@@ -101,20 +101,9 @@ class StreamingContext private (
"both SparkContext and checkpoint as null")
}
- private val conf_ = Option(sc_).map(_.conf).getOrElse(cp_.sparkConf)
+ private[streaming] val isCheckpointPresent = (cp_ != null)
- if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds(conf_) < 0) {
- MetadataCleaner.setDelaySeconds(conf_, cp_.delaySeconds)
- }
-
- if (MetadataCleaner.getDelaySeconds(conf_) < 0) {
- throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
- + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
- }
-
- protected[streaming] val isCheckpointPresent = (cp_ != null)
-
- protected[streaming] val sc: SparkContext = {
+ private[streaming] val sc: SparkContext = {
if (isCheckpointPresent) {
new SparkContext(cp_.sparkConf)
} else {
@@ -122,11 +111,16 @@ class StreamingContext private (
}
}
- protected[streaming] val conf = sc.conf
+ if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) {
+ throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
+ + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
+ }
+
+ private[streaming] val conf = sc.conf
- protected[streaming] val env = SparkEnv.get
+ private[streaming] val env = SparkEnv.get
- protected[streaming] val graph: DStreamGraph = {
+ private[streaming] val graph: DStreamGraph = {
if (isCheckpointPresent) {
cp_.graph.setContext(this)
cp_.graph.restoreCheckpointData()
@@ -139,10 +133,9 @@ class StreamingContext private (
}
}
- protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0)
- protected[streaming] var networkInputTracker: NetworkInputTracker = null
+ private val nextNetworkInputStreamId = new AtomicInteger(0)
- protected[streaming] var checkpointDir: String = {
+ private[streaming] var checkpointDir: String = {
if (isCheckpointPresent) {
sc.setCheckpointDir(cp_.checkpointDir)
cp_.checkpointDir
@@ -151,11 +144,13 @@ class StreamingContext private (
}
}
- protected[streaming] val checkpointDuration: Duration = {
+ private[streaming] val checkpointDuration: Duration = {
if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration
}
- protected[streaming] val scheduler = new JobScheduler(this)
+ private[streaming] val scheduler = new JobScheduler(this)
+
+ private[streaming] val waiter = new ContextWaiter
/**
* Return the associated Spark context
*/
@@ -191,11 +186,11 @@ class StreamingContext private (
}
}
- protected[streaming] def initialCheckpoint: Checkpoint = {
+ private[streaming] def initialCheckpoint: Checkpoint = {
if (isCheckpointPresent) cp_ else null
}
- protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+ private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
/**
* Create an input stream with any arbitrary user implemented network receiver.
@@ -416,7 +411,7 @@ class StreamingContext private (
scheduler.listenerBus.addListener(streamingListener)
}
- protected def validate() {
+ private def validate() {
assert(graph != null, "Graph is null")
graph.validate()
@@ -430,38 +425,37 @@ class StreamingContext private (
/**
* Start the execution of the streams.
*/
- def start() {
+ def start() = synchronized {
validate()
+ scheduler.start()
+ }
- // Get the network input streams
- val networkInputStreams = graph.getInputStreams().filter(s => s match {
- case n: NetworkInputDStream[_] => true
- case _ => false
- }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
-
- // Start the network input tracker (must start before receivers)
- if (networkInputStreams.length > 0) {
- networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
- networkInputTracker.start()
- }
- Thread.sleep(1000)
+ /**
+ * Wait for the execution to stop. Any exceptions that occurs during the execution
+ * will be thrown in this thread.
+ */
+ def awaitTermination() {
+ waiter.waitForStopOrError()
+ }
- // Start the scheduler
- scheduler.start()
+ /**
+ * Wait for the execution to stop. Any exceptions that occurs during the execution
+ * will be thrown in this thread.
+ * @param timeout time to wait in milliseconds
+ */
+ def awaitTermination(timeout: Long) {
+ waiter.waitForStopOrError(timeout)
}
/**
* Stop the execution of the streams.
+ * @param stopSparkContext Stop the associated SparkContext or not
*/
- def stop() {
- try {
- if (scheduler != null) scheduler.stop()
- if (networkInputTracker != null) networkInputTracker.stop()
- sc.stop()
- logInfo("StreamingContext stopped successfully")
- } catch {
- case e: Exception => logWarning("Error while stopping", e)
- }
+ def stop(stopSparkContext: Boolean = true) = synchronized {
+ scheduler.stop()
+ logInfo("StreamingContext stopped successfully")
+ waiter.notifyStop()
+ if (stopSparkContext) sc.stop()
}
}
@@ -472,6 +466,8 @@ class StreamingContext private (
object StreamingContext extends Logging {
+ private[streaming] val DEFAULT_CLEANER_TTL = 3600
+
implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
@@ -515,37 +511,29 @@ object StreamingContext extends Logging {
*/
def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls)
-
- protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
+ private[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second batch intervals.
if (MetadataCleaner.getDelaySeconds(conf) < 0) {
- MetadataCleaner.setDelaySeconds(conf, 3600)
+ MetadataCleaner.setDelaySeconds(conf, DEFAULT_CLEANER_TTL)
}
val sc = new SparkContext(conf)
sc
}
- protected[streaming] def createNewSparkContext(
+ private[streaming] def createNewSparkContext(
master: String,
appName: String,
sparkHome: String,
jars: Seq[String],
environment: Map[String, String]
): SparkContext = {
-
val conf = SparkContext.updatedConf(
new SparkConf(), master, appName, sparkHome, jars, environment)
- // Set the default cleaner delay to an hour if not already set.
- // This should be sufficient for even 1 second batch intervals.
- if (MetadataCleaner.getDelaySeconds(conf) < 0) {
- MetadataCleaner.setDelaySeconds(conf, 3600)
- }
- val sc = new SparkContext(master, appName, sparkHome, jars, environment)
- sc
+ createNewSparkContext(conf)
}
- protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
+ private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
if (prefix == null) {
time.milliseconds.toString
} else if (suffix == null || suffix.length ==0) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 523173d45a..b4c46f5e50 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -483,9 +483,28 @@ class JavaStreamingContext(val ssc: StreamingContext) {
def start() = ssc.start()
/**
- * Stop the execution of the streams.
+ * Wait for the execution to stop. Any exceptions that occurs during the execution
+ * will be thrown in this thread.
+ */
+ def awaitTermination() = ssc.awaitTermination()
+
+ /**
+ * Wait for the execution to stop. Any exceptions that occurs during the execution
+ * will be thrown in this thread.
+ * @param timeout time to wait in milliseconds
+ */
+ def awaitTermination(timeout: Long) = ssc.awaitTermination(timeout)
+
+ /**
+ * Stop the execution of the streams. Will stop the associated JavaSparkContext as well.
*/
def stop() = ssc.stop()
+
+ /**
+ * Stop the execution of the streams.
+ * @param stopSparkContext Stop the associated SparkContext or not
+ */
+ def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext)
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index f01e67fe13..8f84232cab 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -43,7 +43,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
* This ensures that InputDStream.compute() is called strictly on increasing
* times.
*/
- override protected def isTimeValid(time: Time): Boolean = {
+ override private[streaming] def isTimeValid(time: Time): Boolean = {
if (!super.isTimeValid(time)) {
false // Time not valid
} else {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index d41f726f83..0f1f6fc2ce 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -68,7 +68,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
// then this returns an empty RDD. This may happen when recovering from a
// master failure
if (validTime >= graph.startTime) {
- val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
+ val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index c8ee93bf5b..7e0f6b2cdf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.Time
+import scala.util.Try
/**
* Class representing a Spark computation. It may contain multiple Spark jobs.
@@ -25,12 +26,10 @@ import org.apache.spark.streaming.Time
private[streaming]
class Job(val time: Time, func: () => _) {
var id: String = _
+ var result: Try[_] = null
- def run(): Long = {
- val startTime = System.currentTimeMillis
- func()
- val stopTime = System.currentTimeMillis
- (stopTime - startTime)
+ def run() {
+ result = Try(func())
}
def setId(number: Int) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 2fa6853ae0..b5f11d3440 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -17,11 +17,11 @@
package org.apache.spark.streaming.scheduler
-import akka.actor.{Props, Actor}
-import org.apache.spark.SparkEnv
-import org.apache.spark.Logging
+import akka.actor.{ActorRef, ActorSystem, Props, Actor}
+import org.apache.spark.{SparkException, SparkEnv, Logging}
import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
+import scala.util.{Failure, Success, Try}
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
@@ -37,29 +37,38 @@ private[scheduler] case class ClearCheckpointData(time: Time) extends JobGenerat
private[streaming]
class JobGenerator(jobScheduler: JobScheduler) extends Logging {
- val ssc = jobScheduler.ssc
- val graph = ssc.graph
- val eventProcessorActor = ssc.env.actorSystem.actorOf(Props(new Actor {
- def receive = {
- case event: JobGeneratorEvent =>
- logDebug("Got event of type " + event.getClass.getName)
- processEvent(event)
- }
- }))
+ private val ssc = jobScheduler.ssc
+ private val graph = ssc.graph
val clock = {
val clockClass = ssc.sc.conf.get(
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
Class.forName(clockClass).newInstance().asInstanceOf[Clock]
}
- val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
- longTime => eventProcessorActor ! GenerateJobs(new Time(longTime)))
- lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
+ private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
+ longTime => eventActor ! GenerateJobs(new Time(longTime)))
+ private lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
} else {
null
}
+ // eventActor is created when generator starts.
+ // This not being null means the scheduler has been started and not stopped
+ private var eventActor: ActorRef = null
+
+ /** Start generation of jobs */
def start() = synchronized {
+ if (eventActor != null) {
+ throw new SparkException("JobGenerator already started")
+ }
+
+ eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
+ def receive = {
+ case event: JobGeneratorEvent =>
+ logDebug("Got event of type " + event.getClass.getName)
+ processEvent(event)
+ }
+ }), "JobGenerator")
if (ssc.isCheckpointPresent) {
restart()
} else {
@@ -67,22 +76,26 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
}
- def stop() {
- timer.stop()
- if (checkpointWriter != null) checkpointWriter.stop()
- ssc.graph.stop()
- logInfo("JobGenerator stopped")
+ /** Stop generation of jobs */
+ def stop() = synchronized {
+ if (eventActor != null) {
+ timer.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ if (checkpointWriter != null) checkpointWriter.stop()
+ ssc.graph.stop()
+ logInfo("JobGenerator stopped")
+ }
}
/**
* On batch completion, clear old metadata and checkpoint computation.
*/
- private[scheduler] def onBatchCompletion(time: Time) {
- eventProcessorActor ! ClearMetadata(time)
+ def onBatchCompletion(time: Time) {
+ eventActor ! ClearMetadata(time)
}
- private[streaming] def onCheckpointCompletion(time: Time) {
- eventProcessorActor ! ClearCheckpointData(time)
+ def onCheckpointCompletion(time: Time) {
+ eventActor ! ClearCheckpointData(time)
}
/** Processes all events */
@@ -121,14 +134,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val checkpointTime = ssc.initialCheckpoint.checkpointTime
val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds))
val downTimes = checkpointTime.until(restartTime, batchDuration)
- logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", "))
+ logInfo("Batches during down time (" + downTimes.size + " batches): "
+ + downTimes.mkString(", "))
// Batches that were unprocessed before failure
val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering)
- logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", "))
+ logInfo("Batches pending processing (" + pendingTimes.size + " batches): " +
+ pendingTimes.mkString(", "))
// Reschedule jobs for these times
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
- logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", "))
+ logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " +
+ timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
jobScheduler.runJobs(time, graph.generateJobs(time))
)
@@ -141,15 +157,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
- logInfo("\n-----------------------------------------------------\n")
- jobScheduler.runJobs(time, graph.generateJobs(time))
- eventProcessorActor ! DoCheckpoint(time)
+ Try(graph.generateJobs(time)) match {
+ case Success(jobs) => jobScheduler.runJobs(time, jobs)
+ case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e)
+ }
+ eventActor ! DoCheckpoint(time)
}
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- eventProcessorActor ! DoCheckpoint(time)
+ eventActor ! DoCheckpoint(time)
}
/** Clear DStream checkpoint data for the given `time`. */
@@ -166,4 +184,3 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
}
}
-
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 30c070c274..de675d3c7f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -17,36 +17,68 @@
package org.apache.spark.streaming.scheduler
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import scala.util.{Failure, Success, Try}
+import scala.collection.JavaConversions._
import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors}
-import scala.collection.mutable.HashSet
+import akka.actor.{ActorRef, Actor, Props}
+import org.apache.spark.{SparkException, Logging, SparkEnv}
import org.apache.spark.streaming._
+
+private[scheduler] sealed trait JobSchedulerEvent
+private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent
+private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent
+private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent
+
/**
* This class schedules jobs to be run on Spark. It uses the JobGenerator to generate
- * the jobs and runs them using a thread pool. Number of threads
+ * the jobs and runs them using a thread pool.
*/
private[streaming]
class JobScheduler(val ssc: StreamingContext) extends Logging {
- val jobSets = new ConcurrentHashMap[Time, JobSet]
- val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
- val executor = Executors.newFixedThreadPool(numConcurrentJobs)
- val generator = new JobGenerator(this)
+ private val jobSets = new ConcurrentHashMap[Time, JobSet]
+ private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
+ private val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+ private val jobGenerator = new JobGenerator(this)
+ val clock = jobGenerator.clock
val listenerBus = new StreamingListenerBus()
- def clock = generator.clock
+ // These two are created only when scheduler starts.
+ // eventActor not being null means the scheduler has been started and not stopped
+ var networkInputTracker: NetworkInputTracker = null
+ private var eventActor: ActorRef = null
+
+
+ def start() = synchronized {
+ if (eventActor != null) {
+ throw new SparkException("JobScheduler already started")
+ }
- def start() {
- generator.start()
+ eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
+ def receive = {
+ case event: JobSchedulerEvent => processEvent(event)
+ }
+ }), "JobScheduler")
+ listenerBus.start()
+ networkInputTracker = new NetworkInputTracker(ssc)
+ networkInputTracker.start()
+ Thread.sleep(1000)
+ jobGenerator.start()
+ logInfo("JobScheduler started")
}
- def stop() {
- generator.stop()
- executor.shutdown()
- if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
- executor.shutdownNow()
+ def stop() = synchronized {
+ if (eventActor != null) {
+ jobGenerator.stop()
+ networkInputTracker.stop()
+ executor.shutdown()
+ if (!executor.awaitTermination(2, TimeUnit.SECONDS)) {
+ executor.shutdownNow()
+ }
+ listenerBus.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ logInfo("JobScheduler stopped")
}
}
@@ -61,46 +93,67 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
}
}
- def getPendingTimes(): Array[Time] = {
- jobSets.keySet.toArray(new Array[Time](0))
+ def getPendingTimes(): Seq[Time] = {
+ jobSets.keySet.toSeq
+ }
+
+ def reportError(msg: String, e: Throwable) {
+ eventActor ! ErrorReported(msg, e)
}
- private def beforeJobStart(job: Job) {
+ private def processEvent(event: JobSchedulerEvent) {
+ try {
+ event match {
+ case JobStarted(job) => handleJobStart(job)
+ case JobCompleted(job) => handleJobCompletion(job)
+ case ErrorReported(m, e) => handleError(m, e)
+ }
+ } catch {
+ case e: Throwable =>
+ reportError("Error in job scheduler", e)
+ }
+ }
+
+ private def handleJobStart(job: Job) {
val jobSet = jobSets.get(job.time)
if (!jobSet.hasStarted) {
- listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo()))
+ listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo))
}
- jobSet.beforeJobStart(job)
+ jobSet.handleJobStart(job)
logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
- SparkEnv.set(generator.ssc.env)
+ SparkEnv.set(ssc.env)
}
- private def afterJobEnd(job: Job) {
- val jobSet = jobSets.get(job.time)
- jobSet.afterJobStop(job)
- logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
- if (jobSet.hasCompleted) {
- jobSets.remove(jobSet.time)
- generator.onBatchCompletion(jobSet.time)
- logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
- jobSet.totalDelay / 1000.0, jobSet.time.toString,
- jobSet.processingDelay / 1000.0
- ))
- listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo()))
+ private def handleJobCompletion(job: Job) {
+ job.result match {
+ case Success(_) =>
+ val jobSet = jobSets.get(job.time)
+ jobSet.handleJobCompletion(job)
+ logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
+ if (jobSet.hasCompleted) {
+ jobSets.remove(jobSet.time)
+ jobGenerator.onBatchCompletion(jobSet.time)
+ logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
+ jobSet.totalDelay / 1000.0, jobSet.time.toString,
+ jobSet.processingDelay / 1000.0
+ ))
+ listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo))
+ }
+ case Failure(e) =>
+ reportError("Error running job " + job, e)
}
}
- private[streaming]
- class JobHandler(job: Job) extends Runnable {
+ private def handleError(msg: String, e: Throwable) {
+ logError(msg, e)
+ ssc.waiter.notifyError(e)
+ }
+
+ private class JobHandler(job: Job) extends Runnable {
def run() {
- beforeJobStart(job)
- try {
- job.run()
- } catch {
- case e: Exception =>
- logError("Running " + job + " failed", e)
- }
- afterJobEnd(job)
+ eventActor ! JobStarted(job)
+ job.run()
+ eventActor ! JobCompleted(job)
}
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 57268674ea..fcf303aee6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.scheduler
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.streaming.Time
/** Class representing a set of Jobs
@@ -27,25 +27,25 @@ private[streaming]
case class JobSet(time: Time, jobs: Seq[Job]) {
private val incompleteJobs = new HashSet[Job]()
- var submissionTime = System.currentTimeMillis() // when this jobset was submitted
- var processingStartTime = -1L // when the first job of this jobset started processing
- var processingEndTime = -1L // when the last job of this jobset finished processing
+ private val submissionTime = System.currentTimeMillis() // when this jobset was submitted
+ private var processingStartTime = -1L // when the first job of this jobset started processing
+ private var processingEndTime = -1L // when the last job of this jobset finished processing
jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) }
incompleteJobs ++= jobs
- def beforeJobStart(job: Job) {
+ def handleJobStart(job: Job) {
if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
}
- def afterJobStop(job: Job) {
+ def handleJobCompletion(job: Job) {
incompleteJobs -= job
if (hasCompleted) processingEndTime = System.currentTimeMillis()
}
- def hasStarted() = (processingStartTime > 0)
+ def hasStarted = processingStartTime > 0
- def hasCompleted() = incompleteJobs.isEmpty
+ def hasCompleted = incompleteJobs.isEmpty
// Time taken to process all the jobs from the time they started processing
// (i.e. not including the time they wait in the streaming scheduler queue)
@@ -57,7 +57,7 @@ case class JobSet(time: Time, jobs: Seq[Job]) {
processingEndTime - time.milliseconds
}
- def toBatchInfo(): BatchInfo = {
+ def toBatchInfo: BatchInfo = {
new BatchInfo(
time,
submissionTime,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index 75f7244643..0d9733fa69 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -19,8 +19,7 @@ package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkException, Logging, SparkEnv}
import org.apache.spark.SparkContext._
import scala.collection.mutable.HashMap
@@ -32,6 +31,7 @@ import akka.pattern.ask
import akka.dispatch._
import org.apache.spark.storage.BlockId
import org.apache.spark.streaming.{Time, StreamingContext}
+import org.apache.spark.util.AkkaUtils
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
@@ -39,33 +39,47 @@ private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], m
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
/**
- * This class manages the execution of the receivers of NetworkInputDStreams.
+ * This class manages the execution of the receivers of NetworkInputDStreams. Instance of
+ * this class must be created after all input streams have been added and StreamingContext.start()
+ * has been called because it needs the final set of input streams at the time of instantiation.
*/
private[streaming]
-class NetworkInputTracker(
- @transient ssc: StreamingContext,
- @transient networkInputStreams: Array[NetworkInputDStream[_]])
- extends Logging {
+class NetworkInputTracker(ssc: StreamingContext) extends Logging {
+ val networkInputStreams = ssc.graph.getNetworkInputStreams()
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
- val timeout = 5000.milliseconds
+ val timeout = AkkaUtils.askTimeout(ssc.conf)
+
+ // actor is created when generator starts.
+ // This not being null means the tracker has been started and not stopped
+ var actor: ActorRef = null
var currentTime: Time = null
/** Start the actor and receiver execution thread. */
def start() {
- ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
- receiverExecutor.start()
+ if (actor != null) {
+ throw new SparkException("NetworkInputTracker already started")
+ }
+
+ if (!networkInputStreams.isEmpty) {
+ actor = ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
+ receiverExecutor.start()
+ logInfo("NetworkInputTracker started")
+ }
}
/** Stop the receiver execution thread. */
def stop() {
- // TODO: stop the actor as well
- receiverExecutor.interrupt()
- receiverExecutor.stopReceivers()
+ if (!networkInputStreams.isEmpty && actor != null) {
+ receiverExecutor.interrupt()
+ receiverExecutor.stopReceivers()
+ ssc.env.actorSystem.stop(actor)
+ logInfo("NetworkInputTracker stopped")
+ }
}
/** Return all the blocks received from a receiver. */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
index 36225e190c..461ea35064 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
@@ -24,9 +24,10 @@ import org.apache.spark.util.Distribution
sealed trait StreamingListenerEvent
case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent
-
case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent
+/** An event used in the listener to shutdown the listener daemon thread. */
+private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent
/**
* A listener interface for receiving information about an ongoing streaming
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
index 110a20f282..6e6e22e1af 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
@@ -31,7 +31,7 @@ private[spark] class StreamingListenerBus() extends Logging {
private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY)
private var queueFullErrorMessageLogged = false
- new Thread("StreamingListenerBus") {
+ val listenerThread = new Thread("StreamingListenerBus") {
setDaemon(true)
override def run() {
while (true) {
@@ -41,11 +41,18 @@ private[spark] class StreamingListenerBus() extends Logging {
listeners.foreach(_.onBatchStarted(batchStarted))
case batchCompleted: StreamingListenerBatchCompleted =>
listeners.foreach(_.onBatchCompleted(batchCompleted))
+ case StreamingListenerShutdown =>
+ // Get out of the while loop and shutdown the daemon thread
+ return
case _ =>
}
}
}
- }.start()
+ }
+
+ def start() {
+ listenerThread.start()
+ }
def addListener(listener: StreamingListener) {
listeners += listener
@@ -54,9 +61,9 @@ private[spark] class StreamingListenerBus() extends Logging {
def post(event: StreamingListenerEvent) {
val eventAdded = eventQueue.offer(event)
if (!eventAdded && !queueFullErrorMessageLogged) {
- logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
- "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
- "rate at which tasks are being started by the scheduler.")
+ logError("Dropping StreamingListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the StreamingListeners is too slow and cannot keep up with the " +
+ "rate at which events are being started by the scheduler.")
queueFullErrorMessageLogged = true
}
}
@@ -68,7 +75,7 @@ private[spark] class StreamingListenerBus() extends Logging {
*/
def waitUntilEmpty(timeoutMillis: Int): Boolean = {
val finishTime = System.currentTimeMillis + timeoutMillis
- while (!eventQueue.isEmpty()) {
+ while (!eventQueue.isEmpty) {
if (System.currentTimeMillis > finishTime) {
return false
}
@@ -78,4 +85,6 @@ private[spark] class StreamingListenerBus() extends Logging {
}
return true
}
+
+ def stop(): Unit = post(StreamingListenerShutdown)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index d644240405..559c247385 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -20,17 +20,7 @@ package org.apache.spark.streaming.util
private[streaming]
class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
- private val minPollTime = 25L
-
- private val pollTime = {
- if (period / 10.0 > minPollTime) {
- (period / 10.0).toLong
- } else {
- minPollTime
- }
- }
-
- private val thread = new Thread() {
+ private val thread = new Thread("RecurringTimer") {
override def run() { loop }
}
@@ -66,7 +56,6 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
callback(nextTime)
nextTime += period
}
-
} catch {
case e: InterruptedException =>
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 9a187ce031..9406e0e20a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -375,11 +375,7 @@ class BasicOperationsSuite extends TestSuiteBase {
}
test("slice") {
- val conf2 = new SparkConf()
- .setMaster("local[2]")
- .setAppName("BasicOperationsSuite")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
- val ssc = new StreamingContext(new SparkContext(conf2), Seconds(1))
+ val ssc = new StreamingContext(conf, Seconds(1))
val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
val stream = new TestInputStream[Int](ssc, input, 2)
ssc.registerInputStream(stream)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 6499de98c9..9590bca989 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -28,6 +28,8 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.FileInputDStream
import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
/**
* This test suites tests the checkpointing functionality of DStreams -
@@ -142,6 +144,26 @@ class CheckpointSuite extends TestSuiteBase {
ssc = null
}
+ // This tests whether spark conf persists through checkpoints, and certain
+ // configs gets scrubbed
+ test("persistence of conf through checkpoints") {
+ val key = "spark.mykey"
+ val value = "myvalue"
+ System.setProperty(key, value)
+ ssc = new StreamingContext(master, framework, batchDuration)
+ val cp = new Checkpoint(ssc, Time(1000))
+ assert(!cp.sparkConf.contains("spark.driver.host"))
+ assert(!cp.sparkConf.contains("spark.driver.port"))
+ assert(!cp.sparkConf.contains("spark.hostPort"))
+ assert(cp.sparkConf.get(key) === value)
+ ssc.stop()
+ val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
+ assert(!newCp.sparkConf.contains("spark.driver.host"))
+ assert(!newCp.sparkConf.contains("spark.driver.port"))
+ assert(!newCp.sparkConf.contains("spark.hostPort"))
+ assert(newCp.sparkConf.get(key) === value)
+ }
+
// This tests whether the systm can recover from a master failure with simple
// non-stateful operations. This assumes as reliable, replayable input
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
new file mode 100644
index 0000000000..a477d200c9
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+import org.apache.spark.{SparkException, SparkConf, SparkContext}
+import org.apache.spark.util.{Utils, MetadataCleaner}
+
+class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
+
+ val master = "local[2]"
+ val appName = this.getClass.getSimpleName
+ val batchDuration = Seconds(1)
+ val sparkHome = "someDir"
+ val envPair = "key" -> "value"
+ val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100
+
+ var sc: SparkContext = null
+ var ssc: StreamingContext = null
+
+ before {
+ System.clearProperty("spark.cleaner.ttl")
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ ssc = null
+ }
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ }
+
+ test("from no conf constructor") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ assert(ssc.sparkContext.conf.get("spark.master") === master)
+ assert(ssc.sparkContext.conf.get("spark.app.name") === appName)
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from no conf + spark home") {
+ ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil)
+ assert(ssc.conf.get("spark.home") === sparkHome)
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from no conf + spark home + env") {
+ ssc = new StreamingContext(master, appName, batchDuration,
+ sparkHome, Nil, Map(envPair))
+ assert(ssc.conf.getExecutorEnv.exists(_ == envPair))
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from conf without ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(MetadataCleaner.getDelaySeconds(ssc.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from conf with ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl)
+ }
+
+ test("from existing SparkContext without ttl set") {
+ sc = new SparkContext(master, appName)
+ val exception = intercept[SparkException] {
+ ssc = new StreamingContext(sc, batchDuration)
+ }
+ assert(exception.getMessage.contains("ttl"))
+ }
+
+ test("from existing SparkContext with ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl)
+ }
+
+ test("from checkpoint") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ val ssc1 = new StreamingContext(myConf, batchDuration)
+ val cp = new Checkpoint(ssc1, Time(1000))
+ assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl)
+ ssc1.stop()
+ val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
+ assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl)
+ ssc = new StreamingContext(null, cp, null)
+ assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl)
+ }
+
+ test("start multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+
+ ssc.start()
+ intercept[SparkException] {
+ ssc.start()
+ }
+ }
+
+ test("stop multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+ ssc.start()
+ ssc.stop()
+ ssc.stop()
+ ssc = null
+ }
+
+ test("stop only streaming context") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ sc = ssc.sparkContext
+ addInputStream(ssc).register
+ ssc.start()
+ ssc.stop(false)
+ ssc = null
+ assert(sc.makeRDD(1 to 100).collect().size === 100)
+ ssc = new StreamingContext(sc, batchDuration)
+ }
+
+ test("awaitTermination") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => x).register
+
+ // test whether start() blocks indefinitely or not
+ failAfter(2000 millis) {
+ ssc.start()
+ }
+
+ // test whether waitForStop() exits after give amount of time
+ failAfter(1000 millis) {
+ ssc.awaitTermination(500)
+ }
+
+ // test whether waitForStop() does not exit if not time is given
+ val exception = intercept[Exception] {
+ failAfter(1000 millis) {
+ ssc.awaitTermination()
+ throw new Exception("Did not wait for stop")
+ }
+ }
+ assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop")
+
+ // test whether wait exits if context is stopped
+ failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown
+ new Thread() {
+ override def run {
+ Thread.sleep(500)
+ ssc.stop()
+ }
+ }.start()
+ ssc.awaitTermination()
+ }
+ }
+
+ test("awaitTermination with error in task") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => { throw new TestException("error in map task"); x})
+ .foreach(_.count)
+
+ val exception = intercept[Exception] {
+ ssc.start()
+ ssc.awaitTermination(5000)
+ }
+ assert(exception.getMessage.contains("map task"), "Expected exception not thrown")
+ }
+
+ test("awaitTermination with error in job generation") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+
+ inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register
+ val exception = intercept[TestException] {
+ ssc.start()
+ ssc.awaitTermination(5000)
+ }
+ assert(exception.getMessage.contains("transform"), "Expected exception not thrown")
+ }
+
+ def addInputStream(s: StreamingContext): DStream[Int] = {
+ val input = (1 to 100).map(i => (1 to i))
+ val inputStream = new TestInputStream(s, input, 1)
+ s.registerInputStream(inputStream)
+ inputStream
+ }
+}
+
+class TestException(msg: String) extends Exception(msg) \ No newline at end of file
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index b20d02f996..63a07cfbdf 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -137,7 +137,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val conf = new SparkConf()
.setMaster(master)
.setAppName(framework)
- .set("spark.cleaner.ttl", "3600")
+ .set("spark.cleaner.ttl", StreamingContext.DEFAULT_CLEANER_TTL.toString)
// Default before function for any streaming test suite. Override this
// if you want to add your stuff to "before" (i.e., don't call before { } )
@@ -273,10 +273,11 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val startTime = System.currentTimeMillis()
while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
- Thread.sleep(10)
+ ssc.awaitTermination(50)
}
val timeTaken = System.currentTimeMillis() - startTime
-
+ logInfo("Output generated in " + timeTaken + " milliseconds")
+ output.foreach(x => logInfo("[" + x.mkString(",") + "]"))
assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")