author    Tathagata Das <tathagata.das1565@gmail.com>  2014-01-11 23:15:09 -0800
committer Tathagata Das <tathagata.das1565@gmail.com>  2014-01-11 23:15:09 -0800
commit    f5108ffc24eccd21f5d6dc4114ea47b0ab14ab14 (patch)
tree      2b446a3e73398929ab01491e16adcba7ec654736 /streaming
parent    4f39e79c23b32a411a0d5fdc86b5c17ab2250f8d (diff)
Converted JobScheduler to use actors for event handling. Changed protected[streaming] to private[streaming] in StreamingContext and DStream. Added waitForStop to StreamingContext, and added StreamingContextSuite.
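For context, a minimal sketch of how a driver program might use the new waitForStop API introduced by this commit (the socket source, host/port, and batch interval are illustrative, not part of the change):

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object WaitForStopExample {
      def main(args: Array[String]) {
        // 1-second batches on a local two-core master
        val ssc = new StreamingContext("local[2]", "WaitForStopExample", Seconds(1))
        ssc.socketTextStream("localhost", 9999).count().print()
        ssc.start()
        // Block until stop() is called from another thread; any error
        // raised during execution is rethrown here on the driver
        ssc.waitForStop()
      }
    }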
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala | 28
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStream.scala | 56
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala | 8
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala | 111
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala | 2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala | 2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala | 9
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala | 40
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala | 134
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala | 18
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala | 20
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala | 6
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala | 22
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala | 208
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 2
16 files changed, 484 insertions(+), 184 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 1249ef4c3d..108bc2de3e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -40,7 +40,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val graph = ssc.graph
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
- val pendingTimes = ssc.scheduler.getPendingTimes()
+ val pendingTimes = ssc.scheduler.getPendingTimes().toArray
val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf)
val sparkConf = ssc.conf
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
new file mode 100644
index 0000000000..1f5dacb543
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
@@ -0,0 +1,28 @@
+package org.apache.spark.streaming
+
+private[streaming] class ContextWaiter {
+ private var error: Throwable = null
+ private var stopped: Boolean = false
+
+ def notifyError(e: Throwable) = synchronized {
+ error = e
+ notifyAll()
+ }
+
+ def notifyStop() = synchronized {
+ notifyAll()
+ }
+
+ def waitForStopOrError(timeout: Long = -1) = synchronized {
+ // If an error has already occurred, throw it
+ if (error != null) {
+ throw error
+ }
+
+ // If not already stopped, then wait
+ if (!stopped) {
+ if (timeout < 0) wait() else wait(timeout)
+ if (error != null) throw error
+ }
+ }
+}
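A hedged sketch of the intended handshake around the new ContextWaiter (it is private[streaming], so this only compiles inside org.apache.spark.streaming; note that notifyStop() above never sets the stopped flag, so the waiter relies on the notifyAll() wake-up alone):

    val waiter = new ContextWaiter

    // Thread A: blocks until notifyStop() or notifyError() is called
    new Thread() {
      override def run() {
        waiter.waitForStopOrError()  // rethrows anything passed to notifyError()
        println("stopped cleanly")
      }
    }.start()

    // Thread B: signals a clean stop after some work
    Thread.sleep(500)
    waiter.notifyStop()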
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index b98f4a5101..d59146e069 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -53,7 +53,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
*/
abstract class DStream[T: ClassTag] (
- @transient protected[streaming] var ssc: StreamingContext
+ @transient private[streaming] var ssc: StreamingContext
) extends Serializable with Logging {
// =======================================================================
@@ -73,31 +73,31 @@ abstract class DStream[T: ClassTag] (
// Methods and fields available on all DStreams
// =======================================================================
- // RDDs generated, marked as protected[streaming] so that testsuites can access it
+ // RDDs generated, marked as private[streaming] so that testsuites can access it
@transient
- protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
+ private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
// Time zero for the DStream
- protected[streaming] var zeroTime: Time = null
+ private[streaming] var zeroTime: Time = null
// Duration for which the DStream will remember each RDD created
- protected[streaming] var rememberDuration: Duration = null
+ private[streaming] var rememberDuration: Duration = null
// Storage level of the RDDs in the stream
- protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
+ private[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
// Checkpoint details
- protected[streaming] val mustCheckpoint = false
- protected[streaming] var checkpointDuration: Duration = null
- protected[streaming] val checkpointData = new DStreamCheckpointData(this)
+ private[streaming] val mustCheckpoint = false
+ private[streaming] var checkpointDuration: Duration = null
+ private[streaming] val checkpointData = new DStreamCheckpointData(this)
// Reference to whole DStream graph
- protected[streaming] var graph: DStreamGraph = null
+ private[streaming] var graph: DStreamGraph = null
- protected[streaming] def isInitialized = (zeroTime != null)
+ private[streaming] def isInitialized = (zeroTime != null)
// Duration for which the DStream requires its parent DStream to remember each RDD created
- protected[streaming] def parentRememberDuration = rememberDuration
+ private[streaming] def parentRememberDuration = rememberDuration
/** Return the StreamingContext associated with this DStream */
def context = ssc
@@ -137,7 +137,7 @@ abstract class DStream[T: ClassTag] (
* the validity of future times is calculated. This method also recursively initializes
* its parent DStreams.
*/
- protected[streaming] def initialize(time: Time) {
+ private[streaming] def initialize(time: Time) {
if (zeroTime != null && zeroTime != time) {
throw new Exception("ZeroTime is already initialized to " + zeroTime
+ ", cannot initialize it again to " + time)
@@ -163,7 +163,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.initialize(zeroTime))
}
- protected[streaming] def validate() {
+ private[streaming] def validate() {
assert(rememberDuration != null, "Remember duration is set to null")
assert(
@@ -227,7 +227,7 @@ abstract class DStream[T: ClassTag] (
logInfo("Initialized and validated " + this)
}
- protected[streaming] def setContext(s: StreamingContext) {
+ private[streaming] def setContext(s: StreamingContext) {
if (ssc != null && ssc != s) {
throw new Exception("Context is already set in " + this + ", cannot set it again")
}
@@ -236,7 +236,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.setContext(ssc))
}
- protected[streaming] def setGraph(g: DStreamGraph) {
+ private[streaming] def setGraph(g: DStreamGraph) {
if (graph != null && graph != g) {
throw new Exception("Graph is already set in " + this + ", cannot set it again")
}
@@ -244,7 +244,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.setGraph(graph))
}
- protected[streaming] def remember(duration: Duration) {
+ private[streaming] def remember(duration: Duration) {
if (duration != null && duration > rememberDuration) {
rememberDuration = duration
logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
@@ -253,14 +253,14 @@ abstract class DStream[T: ClassTag] (
}
/** Checks whether the 'time' is valid wrt slideDuration for generating RDD */
- protected def isTimeValid(time: Time): Boolean = {
+ private[streaming] def isTimeValid(time: Time): Boolean = {
if (!isInitialized) {
throw new Exception (this + " has not been initialized")
} else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) {
logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime))
false
} else {
- logInfo("Time " + time + " is valid")
+ logDebug("Time " + time + " is valid")
true
}
}
@@ -269,7 +269,7 @@ abstract class DStream[T: ClassTag] (
* Retrieve a precomputed RDD of this DStream, or compute the RDD. This is an internal
* method that should not be called directly.
*/
- protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
+ private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
// If this DStream was not initialized (i.e., zeroTime not set), then do it
// If RDD was already generated, then retrieve it from HashMap
generatedRDDs.get(time) match {
@@ -310,7 +310,7 @@ abstract class DStream[T: ClassTag] (
* that materializes the corresponding RDD. Subclasses of DStream may override this
* to generate their own jobs.
*/
- protected[streaming] def generateJob(time: Time): Option[Job] = {
+ private[streaming] def generateJob(time: Time): Option[Job] = {
getOrCompute(time) match {
case Some(rdd) => {
val jobFunc = () => {
@@ -329,7 +329,7 @@ abstract class DStream[T: ClassTag] (
* implementation clears the old generated RDDs. Subclasses of DStream may override
* this to clear their own metadata along with the generated RDDs.
*/
- protected[streaming] def clearMetadata(time: Time) {
+ private[streaming] def clearMetadata(time: Time) {
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
generatedRDDs --= oldRDDs.keys
logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +
@@ -338,9 +338,9 @@ abstract class DStream[T: ClassTag] (
}
/* Adds metadata to the Stream while it is running.
- * This methd should be overwritten by sublcasses of InputDStream.
+ * This method should be overwritten by subclasses of InputDStream.
*/
- protected[streaming] def addMetadata(metadata: Any) {
+ private[streaming] def addMetadata(metadata: Any) {
if (metadata != null) {
logInfo("Dropping Metadata: " + metadata.toString)
}
@@ -353,14 +353,14 @@ abstract class DStream[T: ClassTag] (
* checkpointData. Subclasses of DStream (especially those of InputDStream) may override
* this method to save custom checkpoint data.
*/
- protected[streaming] def updateCheckpointData(currentTime: Time) {
+ private[streaming] def updateCheckpointData(currentTime: Time) {
logInfo("Updating checkpoint data for time " + currentTime)
checkpointData.update(currentTime)
dependencies.foreach(_.updateCheckpointData(currentTime))
logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData)
}
- protected[streaming] def clearCheckpointData(time: Time) {
+ private[streaming] def clearCheckpointData(time: Time) {
logInfo("Clearing checkpoint data")
checkpointData.cleanup(time)
dependencies.foreach(_.clearCheckpointData(time))
@@ -373,7 +373,7 @@ abstract class DStream[T: ClassTag] (
* from the checkpoint file names stored in checkpointData. Subclasses of DStream that
* override the updateCheckpointData() method would also need to override this method.
*/
- protected[streaming] def restoreCheckpointData() {
+ private[streaming] def restoreCheckpointData() {
// Create RDDs from the checkpoint data
logInfo("Restoring checkpoint data")
checkpointData.restore()
@@ -684,7 +684,7 @@ abstract class DStream[T: ClassTag] (
/**
* Return all the RDDs defined by the Interval object (both end times included)
*/
- protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = {
+ def slice(interval: Interval): Seq[RDD[T]] = {
slice(interval.beginTime, interval.endTime)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index eee9591ffc..668e5324e6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming
-import dstream.InputDStream
+import org.apache.spark.streaming.dstream.{NetworkInputDStream, InputDStream}
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import org.apache.spark.Logging
@@ -103,6 +103,12 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def getOutputStreams() = this.synchronized { outputStreams.toArray }
+ def getNetworkInputStreams() = this.synchronized {
+ inputStreams.filter(_.isInstanceOf[NetworkInputDStream[_]])
+ .map(_.asInstanceOf[NetworkInputDStream[_]])
+ .toArray
+ }
+
def generateJobs(time: Time): Seq[Job] = {
logDebug("Generating jobs for time " + time)
val jobs = this.synchronized {
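As an aside, the filter/map pair in the new getNetworkInputStreams() could be written with collect, which performs the type test and the cast in one pass (a sketch, equivalent in behavior):

    def getNetworkInputStreams() = this.synchronized {
      inputStreams.collect { case n: NetworkInputDStream[_] => n }.toArray
    }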
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index dd34f6f4f2..b20dbdd8cc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -46,7 +46,7 @@ import org.apache.hadoop.conf.Configuration
* information (such as, cluster URL and job name) to internally create a SparkContext, it provides
* methods used to create DStream from various input sources.
*/
-class StreamingContext private (
+class StreamingContext private[streaming] (
sc_ : SparkContext,
cp_ : Checkpoint,
batchDur_ : Duration
@@ -101,20 +101,9 @@ class StreamingContext private (
"both SparkContext and checkpoint as null")
}
- private val conf_ = Option(sc_).map(_.conf).getOrElse(cp_.sparkConf)
+ private[streaming] val isCheckpointPresent = (cp_ != null)
- if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds(conf_) < 0) {
- MetadataCleaner.setDelaySeconds(conf_, cp_.delaySeconds)
- }
-
- if (MetadataCleaner.getDelaySeconds(conf_) < 0) {
- throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
- + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
- }
-
- protected[streaming] val isCheckpointPresent = (cp_ != null)
-
- protected[streaming] val sc: SparkContext = {
+ private[streaming] val sc: SparkContext = {
if (isCheckpointPresent) {
new SparkContext(cp_.sparkConf)
} else {
@@ -122,11 +111,16 @@ class StreamingContext private (
}
}
- protected[streaming] val conf = sc.conf
+ if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) {
+ throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
+ + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
+ }
+
+ private[streaming] val conf = sc.conf
- protected[streaming] val env = SparkEnv.get
+ private[streaming] val env = SparkEnv.get
- protected[streaming] val graph: DStreamGraph = {
+ private[streaming] val graph: DStreamGraph = {
if (isCheckpointPresent) {
cp_.graph.setContext(this)
cp_.graph.restoreCheckpointData()
@@ -139,10 +133,9 @@ class StreamingContext private (
}
}
- protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0)
- protected[streaming] var networkInputTracker: NetworkInputTracker = null
+ private val nextNetworkInputStreamId = new AtomicInteger(0)
- protected[streaming] var checkpointDir: String = {
+ private[streaming] var checkpointDir: String = {
if (isCheckpointPresent) {
sc.setCheckpointDir(cp_.checkpointDir)
cp_.checkpointDir
@@ -151,11 +144,13 @@ class StreamingContext private (
}
}
- protected[streaming] val checkpointDuration: Duration = {
+ private[streaming] val checkpointDuration: Duration = {
if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration
}
- protected[streaming] val scheduler = new JobScheduler(this)
+ private[streaming] val scheduler = new JobScheduler(this)
+
+ private[streaming] val waiter = new ContextWaiter
/**
* Return the associated Spark context
*/
@@ -191,11 +186,11 @@ class StreamingContext private (
}
}
- protected[streaming] def initialCheckpoint: Checkpoint = {
+ private[streaming] def initialCheckpoint: Checkpoint = {
if (isCheckpointPresent) cp_ else null
}
- protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+ private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
/**
* Create an input stream with any arbitrary user implemented network receiver.
@@ -416,7 +411,7 @@ class StreamingContext private (
scheduler.listenerBus.addListener(streamingListener)
}
- protected def validate() {
+ private def validate() {
assert(graph != null, "Graph is null")
graph.validate()
@@ -430,38 +425,36 @@ class StreamingContext private (
/**
* Start the execution of the streams.
*/
- def start() {
+ def start() = synchronized {
validate()
+ scheduler.start()
+ }
- // Get the network input streams
- val networkInputStreams = graph.getInputStreams().filter(s => s match {
- case n: NetworkInputDStream[_] => true
- case _ => false
- }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
-
- // Start the network input tracker (must start before receivers)
- if (networkInputStreams.length > 0) {
- networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
- networkInputTracker.start()
- }
- Thread.sleep(1000)
+ /**
+ * Wait for the execution to stop. Any exceptions that occur during the execution
+ * will be thrown here.
+ */
+ def waitForStop() {
+ waiter.waitForStopOrError()
+ }
- // Start the scheduler
- scheduler.start()
+ /**
+ * Wait for the execution to stop. Any exceptions that occur during the execution
+ * will be thrown here.
+ * @param timeout time to wait, in milliseconds
+ */
+ def waitForStop(timeout: Long) {
+ waiter.waitForStopOrError(timeout)
}
/**
* Stop the execution of the streams.
*/
- def stop() {
- try {
- if (scheduler != null) scheduler.stop()
- if (networkInputTracker != null) networkInputTracker.stop()
- sc.stop()
- logInfo("StreamingContext stopped successfully")
- } catch {
- case e: Exception => logWarning("Error while stopping", e)
- }
+ def stop(stopSparkContext: Boolean = true) = synchronized {
+ scheduler.stop()
+ logInfo("StreamingContext stopped successfully")
+ waiter.notifyStop()
+ if (stopSparkContext) sc.stop()
}
}
@@ -472,6 +465,8 @@ class StreamingContext private (
object StreamingContext extends Logging {
+ private[streaming] val DEFAULT_CLEANER_TTL = 3600
+
implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
@@ -515,37 +510,29 @@ object StreamingContext extends Logging {
*/
def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls)
-
- protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
+ private[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second batch intervals.
if (MetadataCleaner.getDelaySeconds(conf) < 0) {
- MetadataCleaner.setDelaySeconds(conf, 3600)
+ MetadataCleaner.setDelaySeconds(conf, DEFAULT_CLEANER_TTL)
}
val sc = new SparkContext(conf)
sc
}
- protected[streaming] def createNewSparkContext(
+ private[streaming] def createNewSparkContext(
master: String,
appName: String,
sparkHome: String,
jars: Seq[String],
environment: Map[String, String]
): SparkContext = {
-
val conf = SparkContext.updatedConf(
new SparkConf(), master, appName, sparkHome, jars, environment)
- // Set the default cleaner delay to an hour if not already set.
- // This should be sufficient for even 1 second batch intervals.
- if (MetadataCleaner.getDelaySeconds(conf) < 0) {
- MetadataCleaner.setDelaySeconds(conf, 3600)
- }
- val sc = new SparkContext(master, appName, sparkHome, jars, environment)
- sc
+ createNewSparkContext(conf)
}
- protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
+ private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
if (prefix == null) {
time.milliseconds.toString
} else if (suffix == null || suffix.length ==0) {
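The new stopSparkContext flag on stop() lets a program tear down the streaming machinery while keeping the underlying SparkContext usable, as the "stop only streaming context" test added below exercises. A sketch of that pattern (assuming spark.cleaner.ttl is set on the conf, which this commit requires when reusing a SparkContext):

    val sc = new SparkContext(myConf)                // myConf sets spark.cleaner.ttl
    val ssc = new StreamingContext(sc, Seconds(1))
    // ... register streams, ssc.start(), run for a while ...
    ssc.stop(stopSparkContext = false)               // streaming stops, sc survives
    assert(sc.makeRDD(1 to 100).count() == 100L)     // sc can still run jobs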
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index f01e67fe13..8f84232cab 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -43,7 +43,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
* This ensures that InputDStream.compute() is called strictly on increasing
* times.
*/
- override protected def isTimeValid(time: Time): Boolean = {
+ override private[streaming] def isTimeValid(time: Time): Boolean = {
if (!super.isTimeValid(time)) {
false // Time not valid
} else {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index d41f726f83..0f1f6fc2ce 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -68,7 +68,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
// then this returns an empty RDD. This may happen when recovering from a
// master failure
if (validTime >= graph.startTime) {
- val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
+ val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index c8ee93bf5b..7e0f6b2cdf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.Time
+import scala.util.Try
/**
* Class representing a Spark computation. It may contain multiple Spark jobs.
@@ -25,12 +26,10 @@ import org.apache.spark.streaming.Time
private[streaming]
class Job(val time: Time, func: () => _) {
var id: String = _
+ var result: Try[_] = null
- def run(): Long = {
- val startTime = System.currentTimeMillis
- func()
- val stopTime = System.currentTimeMillis
- (stopTime - startTime)
+ def run() {
+ result = Try(func())
}
def setId(number: Int) {
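Recording the outcome in a scala.util.Try (instead of returning a run time) is what lets the JobScheduler pattern-match on success or failure after the fact. A self-contained sketch of the pattern:

    import scala.util.{Failure, Success, Try}

    def riskyComputation(): Int = {
      if (math.random < 0.5) sys.error("boom")
      42
    }

    // Try captures either the value or the thrown exception
    Try(riskyComputation()) match {
      case Success(v) => println("job finished: " + v)
      case Failure(e) => println("job failed: " + e.getMessage)
    }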
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 2fa6853ae0..caed1b3755 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -22,6 +22,7 @@ import org.apache.spark.SparkEnv
import org.apache.spark.Logging
import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
+import scala.util.{Failure, Success, Try}
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
@@ -37,9 +38,9 @@ private[scheduler] case class ClearCheckpointData(time: Time) extends JobGenerat
private[streaming]
class JobGenerator(jobScheduler: JobScheduler) extends Logging {
- val ssc = jobScheduler.ssc
- val graph = ssc.graph
- val eventProcessorActor = ssc.env.actorSystem.actorOf(Props(new Actor {
+ private val ssc = jobScheduler.ssc
+ private val graph = ssc.graph
+ private val eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
def receive = {
case event: JobGeneratorEvent =>
logDebug("Got event of type " + event.getClass.getName)
@@ -51,9 +52,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
Class.forName(clockClass).newInstance().asInstanceOf[Clock]
}
- val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
- longTime => eventProcessorActor ! GenerateJobs(new Time(longTime)))
- lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
+ private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
+ longTime => eventActor ! GenerateJobs(new Time(longTime)))
+ private lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
} else {
null
@@ -77,12 +78,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/**
* On batch completion, clear old metadata and checkpoint computation.
*/
- private[scheduler] def onBatchCompletion(time: Time) {
- eventProcessorActor ! ClearMetadata(time)
+ def onBatchCompletion(time: Time) {
+ eventActor ! ClearMetadata(time)
}
- private[streaming] def onCheckpointCompletion(time: Time) {
- eventProcessorActor ! ClearCheckpointData(time)
+ def onCheckpointCompletion(time: Time) {
+ eventActor ! ClearCheckpointData(time)
}
/** Processes all events */
@@ -121,14 +122,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val checkpointTime = ssc.initialCheckpoint.checkpointTime
val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds))
val downTimes = checkpointTime.until(restartTime, batchDuration)
- logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", "))
+ logInfo("Batches during down time (" + downTimes.size + " batches): "
+ + downTimes.mkString(", "))
// Batches that were unprocessed before failure
val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering)
- logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", "))
+ logInfo("Batches pending processing (" + pendingTimes.size + " batches): " +
+ pendingTimes.mkString(", "))
// Reschedule jobs for these times
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
- logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", "))
+ logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " +
+ timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
jobScheduler.runJobs(time, graph.generateJobs(time))
)
@@ -141,15 +145,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
- logInfo("\n-----------------------------------------------------\n")
- jobScheduler.runJobs(time, graph.generateJobs(time))
- eventProcessorActor ! DoCheckpoint(time)
+ Try(graph.generateJobs(time)) match {
+ case Success(jobs) => jobScheduler.runJobs(time, jobs)
+ case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e)
+ }
+ eventActor ! DoCheckpoint(time)
}
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- eventProcessorActor ! DoCheckpoint(time)
+ eventActor ! DoCheckpoint(time)
}
/** Clear DStream checkpoint data for the given `time`. */
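Both JobGenerator and JobScheduler now funnel their state changes through a single Akka actor, which serializes event handling without explicit locks. A self-contained sketch of the pattern (stand-alone ActorSystem for illustration only; in Spark the actor is created on ssc.env.actorSystem):

    import akka.actor.{Actor, ActorSystem, Props}

    sealed trait Event
    case class GenerateJobs(timeMs: Long) extends Event
    case class ClearMetadata(timeMs: Long) extends Event

    val system = ActorSystem("sketch")
    val eventActor = system.actorOf(Props(new Actor {
      def receive = {
        case GenerateJobs(t)  => println("generate jobs for " + t)
        case ClearMetadata(t) => println("clear metadata for " + t)
      }
    }))

    eventActor ! GenerateJobs(1000L)   // fire-and-forget; handled one at a time
    eventActor ! ClearMetadata(1000L)
    Thread.sleep(100)                  // give the actor a moment, then shut down
    system.shutdown()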
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 30c070c274..b28ff5d9d8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -17,36 +17,60 @@
package org.apache.spark.streaming.scheduler
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import scala.util.{Failure, Success, Try}
+import scala.collection.JavaConversions._
import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors}
-import scala.collection.mutable.HashSet
+import akka.actor.{ActorRef, Actor, Props}
+import org.apache.spark.{SparkException, Logging, SparkEnv}
import org.apache.spark.streaming._
+
+private[scheduler] sealed trait JobSchedulerEvent
+private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent
+private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent
+private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent
+
/**
* This class schedules jobs to be run on Spark. It uses the JobGenerator to generate
- * the jobs and runs them using a thread pool. Number of threads
+ * the jobs and runs them using a thread pool.
*/
private[streaming]
class JobScheduler(val ssc: StreamingContext) extends Logging {
- val jobSets = new ConcurrentHashMap[Time, JobSet]
- val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
- val executor = Executors.newFixedThreadPool(numConcurrentJobs)
- val generator = new JobGenerator(this)
+ private val jobSets = new ConcurrentHashMap[Time, JobSet]
+ private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
+ private val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+ private val jobGenerator = new JobGenerator(this)
+ private val eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
+ def receive = {
+ case event: JobSchedulerEvent => processEvent(event)
+ }
+ }))
+ val clock = jobGenerator.clock // used by testsuites
val listenerBus = new StreamingListenerBus()
- def clock = generator.clock
+ var networkInputTracker: NetworkInputTracker = null
- def start() {
- generator.start()
+ def start() = synchronized {
+ if (networkInputTracker != null) {
+ throw new SparkException("StreamingContext already started")
+ }
+ networkInputTracker = new NetworkInputTracker(ssc)
+ networkInputTracker.start()
+ Thread.sleep(1000)
+ jobGenerator.start()
+ logInfo("JobScheduler started")
}
- def stop() {
- generator.stop()
- executor.shutdown()
- if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
- executor.shutdownNow()
+ def stop() = synchronized {
+ if (eventActor != null) {
+ jobGenerator.stop()
+ networkInputTracker.stop()
+ executor.shutdown()
+ if (!executor.awaitTermination(2, TimeUnit.SECONDS)) {
+ executor.shutdownNow()
+ }
+ logInfo("JobScheduler stopped")
}
}
@@ -61,46 +85,68 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
}
}
- def getPendingTimes(): Array[Time] = {
- jobSets.keySet.toArray(new Array[Time](0))
+ def getPendingTimes(): Seq[Time] = {
+ jobSets.keySet.toSeq
+ }
+
+ def reportError(msg: String, e: Throwable) {
+ eventActor ! ErrorReported(msg, e)
+ }
+
+ private def processEvent(event: JobSchedulerEvent) {
+ try {
+ event match {
+ case JobStarted(job) => handleJobStart(job)
+ case JobCompleted(job) => handleJobCompletion(job)
+ case ErrorReported(m, e) => handleError(m, e)
+ }
+ } catch {
+ case e: Throwable =>
+ reportError("Error in job scheduler", e)
+ }
+
}
- private def beforeJobStart(job: Job) {
+ private def handleJobStart(job: Job) {
val jobSet = jobSets.get(job.time)
if (!jobSet.hasStarted) {
- listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo()))
+ listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo))
}
- jobSet.beforeJobStart(job)
+ jobSet.handleJobStart(job)
logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
- SparkEnv.set(generator.ssc.env)
+ SparkEnv.set(ssc.env)
}
- private def afterJobEnd(job: Job) {
- val jobSet = jobSets.get(job.time)
- jobSet.afterJobStop(job)
- logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
- if (jobSet.hasCompleted) {
- jobSets.remove(jobSet.time)
- generator.onBatchCompletion(jobSet.time)
- logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
- jobSet.totalDelay / 1000.0, jobSet.time.toString,
- jobSet.processingDelay / 1000.0
- ))
- listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo()))
+ private def handleJobCompletion(job: Job) {
+ job.result match {
+ case Success(_) =>
+ val jobSet = jobSets.get(job.time)
+ jobSet.handleJobCompletion(job)
+ logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
+ if (jobSet.hasCompleted) {
+ jobSets.remove(jobSet.time)
+ jobGenerator.onBatchCompletion(jobSet.time)
+ logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
+ jobSet.totalDelay / 1000.0, jobSet.time.toString,
+ jobSet.processingDelay / 1000.0
+ ))
+ listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo))
+ }
+ case Failure(e) =>
+ reportError("Error running job " + job, e)
}
}
- private[streaming]
- class JobHandler(job: Job) extends Runnable {
+ private def handleError(msg: String, e: Throwable) {
+ logError(msg, e)
+ ssc.waiter.notifyError(e)
+ }
+
+ private class JobHandler(job: Job) extends Runnable {
def run() {
- beforeJobStart(job)
- try {
- job.run()
- } catch {
- case e: Exception =>
- logError("Running " + job + " failed", e)
- }
- afterJobEnd(job)
+ eventActor ! JobStarted(job)
+ job.run()
+ eventActor ! JobCompleted(job)
}
}
}
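The JobHandler above brackets each job with JobStarted/JobCompleted messages, so all bookkeeping happens on the event actor while the job body runs on the thread pool. A stripped-down sketch of that bracketing (println stands in for the actor sends):

    import java.util.concurrent.{Executors, TimeUnit}

    val pool = Executors.newFixedThreadPool(1)

    def runJob(id: Int, body: () => Unit) {
      pool.execute(new Runnable {
        def run() {
          println("job " + id + " started")    // in Spark: eventActor ! JobStarted(job)
          body()                               // job.run() records its own Try result
          println("job " + id + " completed")  // in Spark: eventActor ! JobCompleted(job)
        }
      })
    }

    runJob(0, () => Thread.sleep(100))
    pool.shutdown()
    pool.awaitTermination(2, TimeUnit.SECONDS)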
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 57268674ea..fcf303aee6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.scheduler
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.streaming.Time
/** Class representing a set of Jobs
@@ -27,25 +27,25 @@ private[streaming]
case class JobSet(time: Time, jobs: Seq[Job]) {
private val incompleteJobs = new HashSet[Job]()
- var submissionTime = System.currentTimeMillis() // when this jobset was submitted
- var processingStartTime = -1L // when the first job of this jobset started processing
- var processingEndTime = -1L // when the last job of this jobset finished processing
+ private val submissionTime = System.currentTimeMillis() // when this jobset was submitted
+ private var processingStartTime = -1L // when the first job of this jobset started processing
+ private var processingEndTime = -1L // when the last job of this jobset finished processing
jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) }
incompleteJobs ++= jobs
- def beforeJobStart(job: Job) {
+ def handleJobStart(job: Job) {
if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
}
- def afterJobStop(job: Job) {
+ def handleJobCompletion(job: Job) {
incompleteJobs -= job
if (hasCompleted) processingEndTime = System.currentTimeMillis()
}
- def hasStarted() = (processingStartTime > 0)
+ def hasStarted = processingStartTime > 0
- def hasCompleted() = incompleteJobs.isEmpty
+ def hasCompleted = incompleteJobs.isEmpty
// Time taken to process all the jobs from the time they started processing
// (i.e. not including the time they wait in the streaming scheduler queue)
@@ -57,7 +57,7 @@ case class JobSet(time: Time, jobs: Seq[Job]) {
processingEndTime - time.milliseconds
}
- def toBatchInfo(): BatchInfo = {
+ def toBatchInfo: BatchInfo = {
new BatchInfo(
time,
submissionTime,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index 75f7244643..34fb158205 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -42,11 +42,9 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) ext
* This class manages the execution of the receivers of NetworkInputDStreams.
*/
private[streaming]
-class NetworkInputTracker(
- @transient ssc: StreamingContext,
- @transient networkInputStreams: Array[NetworkInputDStream[_]])
- extends Logging {
+class NetworkInputTracker(ssc: StreamingContext) extends Logging {
+ val networkInputStreams = ssc.graph.getNetworkInputStreams()
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
@@ -57,15 +55,19 @@ class NetworkInputTracker(
/** Start the actor and receiver execution thread. */
def start() {
- ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
- receiverExecutor.start()
+ if (!networkInputStreams.isEmpty) {
+ ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
+ receiverExecutor.start()
+ }
}
/** Stop the receiver execution thread. */
def stop() {
- // TODO: stop the actor as well
- receiverExecutor.interrupt()
- receiverExecutor.stopReceivers()
+ if (!networkInputStreams.isEmpty) {
+ receiverExecutor.interrupt()
+ receiverExecutor.stopReceivers()
+ logInfo("NetworkInputTracker stopped")
+ }
}
/** Return all the blocks received from a receiver. */
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index ee6b433d1f..2e3a1e66ad 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -375,11 +375,7 @@ class BasicOperationsSuite extends TestSuiteBase {
}
test("slice") {
- val conf2 = new SparkConf()
- .setMaster("local[2]")
- .setAppName("BasicOperationsSuite")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
- val ssc = new StreamingContext(new SparkContext(conf2), Seconds(1))
+ val ssc = new StreamingContext(conf, Seconds(1))
val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
val stream = new TestInputStream[Int](ssc, input, 2)
ssc.registerInputStream(stream)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 6499de98c9..9590bca989 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -28,6 +28,8 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.FileInputDStream
import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
/**
* This test suite tests the checkpointing functionality of DStreams -
@@ -142,6 +144,26 @@ class CheckpointSuite extends TestSuiteBase {
ssc = null
}
+ // This tests whether the SparkConf persists through checkpoints, and whether certain
+ // configs get scrubbed
+ test("persistence of conf through checkpoints") {
+ val key = "spark.mykey"
+ val value = "myvalue"
+ System.setProperty(key, value)
+ ssc = new StreamingContext(master, framework, batchDuration)
+ val cp = new Checkpoint(ssc, Time(1000))
+ assert(!cp.sparkConf.contains("spark.driver.host"))
+ assert(!cp.sparkConf.contains("spark.driver.port"))
+ assert(!cp.sparkConf.contains("spark.hostPort"))
+ assert(cp.sparkConf.get(key) === value)
+ ssc.stop()
+ val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
+ assert(!newCp.sparkConf.contains("spark.driver.host"))
+ assert(!newCp.sparkConf.contains("spark.driver.port"))
+ assert(!newCp.sparkConf.contains("spark.hostPort"))
+ assert(newCp.sparkConf.get(key) === value)
+ }
+
// This tests whether the system can recover from a master failure with simple
// non-stateful operations. This assumes a reliable, replayable input
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
new file mode 100644
index 0000000000..9eb9b3684c
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+import org.apache.spark.{SparkException, SparkConf, SparkContext}
+import org.apache.spark.util.{Utils, MetadataCleaner}
+
+class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
+
+ val master = "local[2]"
+ val appName = this.getClass.getSimpleName
+ val batchDuration = Seconds(1)
+ val sparkHome = "someDir"
+ val envPair = "key" -> "value"
+ val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100
+
+ var sc: SparkContext = null
+ var ssc: StreamingContext = null
+
+ before {
+ System.clearProperty("spark.cleaner.ttl")
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ ssc = null
+ }
+ }
+
+ test("from no conf constructor") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ assert(ssc.sparkContext.conf.get("spark.master") === master)
+ assert(ssc.sparkContext.conf.get("spark.app.name") === appName)
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from no conf + spark home") {
+ ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil)
+ assert(ssc.conf.get("spark.home") === sparkHome)
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from no conf + spark home + env") {
+ ssc = new StreamingContext(master, appName, batchDuration,
+ sparkHome, Nil, Map(envPair))
+ assert(ssc.conf.getExecutorEnv.exists(_ == envPair))
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from conf without ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(MetadataCleaner.getDelaySeconds(ssc.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from conf with ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl)
+ }
+
+ test("from existing SparkContext without ttl set") {
+ sc = new SparkContext(master, appName)
+ val exception = intercept[SparkException] {
+ ssc = new StreamingContext(sc, batchDuration)
+ }
+ assert(exception.getMessage.contains("ttl"))
+ }
+
+ test("from existing SparkContext with ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl)
+ }
+
+ test("from checkpoint") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ val ssc1 = new StreamingContext(myConf, batchDuration)
+ val cp = new Checkpoint(ssc1, Time(1000))
+ assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl)
+ ssc1.stop()
+ val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
+ assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl)
+ ssc = new StreamingContext(null, cp, null)
+ assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl)
+ }
+
+ test("start multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+
+ ssc.start()
+ intercept[SparkException] {
+ ssc.start()
+ }
+ }
+
+ test("stop multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ ssc.stop()
+ ssc.stop()
+ ssc = null
+ }
+
+ test("stop only streaming context") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ ssc.stop(false)
+ ssc = null
+ assert(sc.makeRDD(1 to 100).collect().size === 100)
+ }
+
+ test("waitForStop") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => x).register
+
+ // test whether start() blocks indefinitely or not
+ failAfter(2000 millis) {
+ ssc.start()
+ }
+
+ // test whether waitForStop() exits after the given amount of time
+ failAfter(1000 millis) {
+ ssc.waitForStop(500)
+ }
+
+ // test whether waitForStop() does not exit if no timeout is given
+ val exception = intercept[Exception] {
+ failAfter(1000 millis) {
+ ssc.waitForStop()
+ throw new Exception("Did not wait for stop")
+ }
+ }
+ assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop")
+
+ // test whether wait exits if context is stopped
+ failAfter(10000 millis) { // 10 seconds because Spark takes a long time to shut down
+ new Thread() {
+ override def run {
+ Thread.sleep(500)
+ ssc.stop()
+ }
+ }.start()
+ ssc.waitForStop()
+ }
+ }
+
+ test("waitForStop with error in task") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => { throw new TestException("error in map task"); x})
+ .foreach(_.count)
+
+ val exception = intercept[Exception] {
+ ssc.start()
+ ssc.waitForStop(5000)
+ }
+ assert(exception.getMessage.contains("map task"), "Expected exception not thrown")
+ }
+
+ test("waitForStop with error in job generation") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+
+ inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register
+ val exception = intercept[TestException] {
+ ssc.start()
+ ssc.waitForStop(5000)
+ }
+ assert(exception.getMessage.contains("transform"), "Expected exception not thrown")
+ }
+
+ def addInputStream(s: StreamingContext): DStream[Int] = {
+ val input = (1 to 100).map(i => (1 to i))
+ val inputStream = new TestInputStream(s, input, 1)
+ s.registerInputStream(inputStream)
+ inputStream
+ }
+}
+
+class TestException(msg: String) extends Exception(msg) \ No newline at end of file
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index b20d02f996..3569624d51 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -137,7 +137,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val conf = new SparkConf()
.setMaster(master)
.setAppName(framework)
- .set("spark.cleaner.ttl", "3600")
+ .set("spark.cleaner.ttl", StreamingContext.DEFAULT_CLEANER_TTL.toString)
// Default before function for any streaming test suite. Override this
// if you want to add your stuff to "before" (i.e., don't call before { } )