diff options
Diffstat (limited to 'streaming')
26 files changed, 653 insertions, 256 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index f586baee0f..5046a1d53f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -40,7 +40,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val graph = ssc.graph val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration - val pendingTimes = ssc.scheduler.getPendingTimes() + val pendingTimes = ssc.scheduler.getPendingTimes().toArray val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConf = ssc.conf @@ -271,6 +271,6 @@ class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoade } catch { case e: Exception => } - return super.resolveClass(desc) + super.resolveClass(desc) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala new file mode 100644 index 0000000000..1f5dacb543 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala @@ -0,0 +1,28 @@ +package org.apache.spark.streaming + +private[streaming] class ContextWaiter { + private var error: Throwable = null + private var stopped: Boolean = false + + def notifyError(e: Throwable) = synchronized { + error = e + notifyAll() + } + + def notifyStop() = synchronized { + notifyAll() + } + + def waitForStopOrError(timeout: Long = -1) = synchronized { + // If already had error, then throw it + if (error != null) { + throw error + } + + // If not already stopped, then wait + if (!stopped) { + if (timeout < 0) wait() else wait(timeout) + if (error != null) throw error + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index b98f4a5101..f760093579 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -17,19 +17,20 @@ package org.apache.spark.streaming -import StreamingContext._ -import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.MetadataCleaner +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import scala.deprecated import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import StreamingContext._ +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.scheduler.Job +import org.apache.spark.util.MetadataCleaner /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -53,7 +54,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} */ abstract class DStream[T: ClassTag] ( - @transient protected[streaming] var ssc: StreamingContext + @transient private[streaming] var ssc: StreamingContext ) extends Serializable with Logging { // ======================================================================= @@ -73,31 +74,31 @@ abstract class DStream[T: ClassTag] ( // Methods and fields available on all DStreams // ======================================================================= - // RDDs generated, marked as protected[streaming] so that testsuites can access it + // RDDs generated, marked as private[streaming] so that testsuites can access it @transient - protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () + private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected[streaming] var zeroTime: Time = null + private[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected[streaming] var rememberDuration: Duration = null + private[streaming] var rememberDuration: Duration = null // Storage level of the RDDs in the stream - protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE + private[streaming] var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint details - protected[streaming] val mustCheckpoint = false - protected[streaming] var checkpointDuration: Duration = null - protected[streaming] val checkpointData = new DStreamCheckpointData(this) + private[streaming] val mustCheckpoint = false + private[streaming] var checkpointDuration: Duration = null + private[streaming] val checkpointData = new DStreamCheckpointData(this) // Reference to whole DStream graph - protected[streaming] var graph: DStreamGraph = null + private[streaming] var graph: DStreamGraph = null - protected[streaming] def isInitialized = (zeroTime != null) + private[streaming] def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created - protected[streaming] def parentRememberDuration = rememberDuration + private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ def context = ssc @@ -137,7 +138,7 @@ abstract class DStream[T: ClassTag] ( * the validity of future times is calculated. This method also recursively initializes * its parent DStreams. */ - protected[streaming] def initialize(time: Time) { + private[streaming] def initialize(time: Time) { if (zeroTime != null && zeroTime != time) { throw new Exception("ZeroTime is already initialized to " + zeroTime + ", cannot initialize it again to " + time) @@ -163,7 +164,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.initialize(zeroTime)) } - protected[streaming] def validate() { + private[streaming] def validate() { assert(rememberDuration != null, "Remember duration is set to null") assert( @@ -227,7 +228,7 @@ abstract class DStream[T: ClassTag] ( logInfo("Initialized and validated " + this) } - protected[streaming] def setContext(s: StreamingContext) { + private[streaming] def setContext(s: StreamingContext) { if (ssc != null && ssc != s) { throw new Exception("Context is already set in " + this + ", cannot set it again") } @@ -236,7 +237,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.setContext(ssc)) } - protected[streaming] def setGraph(g: DStreamGraph) { + private[streaming] def setGraph(g: DStreamGraph) { if (graph != null && graph != g) { throw new Exception("Graph is already set in " + this + ", cannot set it again") } @@ -244,7 +245,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def remember(duration: Duration) { + private[streaming] def remember(duration: Duration) { if (duration != null && duration > rememberDuration) { rememberDuration = duration logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) @@ -253,14 +254,14 @@ abstract class DStream[T: ClassTag] ( } /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */ - protected def isTimeValid(time: Time): Boolean = { + private[streaming] def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this + " has not been initialized") } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) false } else { - logInfo("Time " + time + " is valid") + logDebug("Time " + time + " is valid") true } } @@ -269,7 +270,7 @@ abstract class DStream[T: ClassTag] ( * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal * method that should not be called directly. */ - protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { + private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { // If this DStream was not initialized (i.e., zeroTime not set), then do it // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { @@ -310,7 +311,7 @@ abstract class DStream[T: ClassTag] ( * that materializes the corresponding RDD. Subclasses of DStream may override this * to generate their own jobs. */ - protected[streaming] def generateJob(time: Time): Option[Job] = { + private[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { val jobFunc = () => { @@ -329,7 +330,7 @@ abstract class DStream[T: ClassTag] ( * implementation clears the old generated RDDs. Subclasses of DStream may override * this to clear their own metadata along with the generated RDDs. */ - protected[streaming] def clearMetadata(time: Time) { + private[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) generatedRDDs --= oldRDDs.keys logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + @@ -338,9 +339,9 @@ abstract class DStream[T: ClassTag] ( } /* Adds metadata to the Stream while it is running. - * This methd should be overwritten by sublcasses of InputDStream. + * This method should be overwritten by sublcasses of InputDStream. */ - protected[streaming] def addMetadata(metadata: Any) { + private[streaming] def addMetadata(metadata: Any) { if (metadata != null) { logInfo("Dropping Metadata: " + metadata.toString) } @@ -353,18 +354,18 @@ abstract class DStream[T: ClassTag] ( * checkpointData. Subclasses of DStream (especially those of InputDStream) may override * this method to save custom checkpoint data. */ - protected[streaming] def updateCheckpointData(currentTime: Time) { - logInfo("Updating checkpoint data for time " + currentTime) + private[streaming] def updateCheckpointData(currentTime: Time) { + logDebug("Updating checkpoint data for time " + currentTime) checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) } - protected[streaming] def clearCheckpointData(time: Time) { - logInfo("Clearing checkpoint data") + private[streaming] def clearCheckpointData(time: Time) { + logDebug("Clearing checkpoint data") checkpointData.cleanup(time) dependencies.foreach(_.clearCheckpointData(time)) - logInfo("Cleared checkpoint data") + logDebug("Cleared checkpoint data") } /** @@ -373,7 +374,7 @@ abstract class DStream[T: ClassTag] ( * from the checkpoint file names stored in checkpointData. Subclasses of DStream that * override the updateCheckpointData() method would also need to override this method. */ - protected[streaming] def restoreCheckpointData() { + private[streaming] def restoreCheckpointData() { // Create RDDs from the checkpoint data logInfo("Restoring checkpoint data") checkpointData.restore() @@ -487,15 +488,29 @@ abstract class DStream[T: ClassTag] ( * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreach(foreachFunc: RDD[T] => Unit) { - this.foreach((r: RDD[T], t: Time) => foreachFunc(r)) + @deprecated("use foreachRDD", "0.9.0") + def foreach(foreachFunc: RDD[T] => Unit) = this.foreachRDD(foreachFunc) + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + @deprecated("use foreachRDD", "0.9.0") + def foreach(foreachFunc: (RDD[T], Time) => Unit) = this.foreachRDD(foreachFunc) + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: RDD[T] => Unit) { + this.foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreach(foreachFunc: (RDD[T], Time) => Unit) { + def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { ssc.registerOutputStream(new ForEachDStream(this, context.sparkContext.clean(foreachFunc))) } @@ -684,7 +699,7 @@ abstract class DStream[T: ClassTag] ( /** * Return all the RDDs defined by the Interval object (both end times included) */ - protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { + def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) } @@ -719,7 +734,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreach(saveFunc) + this.foreachRDD(saveFunc) } /** @@ -732,7 +747,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreach(saveFunc) + this.foreachRDD(saveFunc) } def register() { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index eee9591ffc..668e5324e6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import dstream.InputDStream +import org.apache.spark.streaming.dstream.{NetworkInputDStream, InputDStream} import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -103,6 +103,12 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getNetworkInputStreams() = this.synchronized { + inputStreams.filter(_.isInstanceOf[NetworkInputDStream[_]]) + .map(_.asInstanceOf[NetworkInputDStream[_]]) + .toArray + } + def generateJobs(time: Time): Seq[Job] = { logDebug("Generating jobs for time " + time) val jobs = this.synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala index 56dbcbda23..69d80c3711 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala @@ -582,7 +582,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreach(saveFunc) + self.foreachRDD(saveFunc) } /** @@ -612,7 +612,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreach(saveFunc) + self.foreachRDD(saveFunc) } private def getKeyClass() = implicitly[ClassTag[K]].runtimeClass diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index dd34f6f4f2..ee83ae902b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -46,7 +46,7 @@ import org.apache.hadoop.conf.Configuration * information (such as, cluster URL and job name) to internally create a SparkContext, it provides * methods used to create DStream from various input sources. */ -class StreamingContext private ( +class StreamingContext private[streaming] ( sc_ : SparkContext, cp_ : Checkpoint, batchDur_ : Duration @@ -101,20 +101,9 @@ class StreamingContext private ( "both SparkContext and checkpoint as null") } - private val conf_ = Option(sc_).map(_.conf).getOrElse(cp_.sparkConf) + private[streaming] val isCheckpointPresent = (cp_ != null) - if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds(conf_) < 0) { - MetadataCleaner.setDelaySeconds(conf_, cp_.delaySeconds) - } - - if (MetadataCleaner.getDelaySeconds(conf_) < 0) { - throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " - + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") - } - - protected[streaming] val isCheckpointPresent = (cp_ != null) - - protected[streaming] val sc: SparkContext = { + private[streaming] val sc: SparkContext = { if (isCheckpointPresent) { new SparkContext(cp_.sparkConf) } else { @@ -122,11 +111,16 @@ class StreamingContext private ( } } - protected[streaming] val conf = sc.conf + if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { + throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " + + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") + } + + private[streaming] val conf = sc.conf - protected[streaming] val env = SparkEnv.get + private[streaming] val env = SparkEnv.get - protected[streaming] val graph: DStreamGraph = { + private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) cp_.graph.restoreCheckpointData() @@ -139,10 +133,9 @@ class StreamingContext private ( } } - protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) - protected[streaming] var networkInputTracker: NetworkInputTracker = null + private val nextNetworkInputStreamId = new AtomicInteger(0) - protected[streaming] var checkpointDir: String = { + private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { sc.setCheckpointDir(cp_.checkpointDir) cp_.checkpointDir @@ -151,11 +144,13 @@ class StreamingContext private ( } } - protected[streaming] val checkpointDuration: Duration = { + private[streaming] val checkpointDuration: Duration = { if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration } - protected[streaming] val scheduler = new JobScheduler(this) + private[streaming] val scheduler = new JobScheduler(this) + + private[streaming] val waiter = new ContextWaiter /** * Return the associated Spark context */ @@ -191,11 +186,11 @@ class StreamingContext private ( } } - protected[streaming] def initialCheckpoint: Checkpoint = { + private[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } - protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() /** * Create an input stream with any arbitrary user implemented network receiver. @@ -416,7 +411,7 @@ class StreamingContext private ( scheduler.listenerBus.addListener(streamingListener) } - protected def validate() { + private def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -430,38 +425,37 @@ class StreamingContext private ( /** * Start the execution of the streams. */ - def start() { + def start() = synchronized { validate() + scheduler.start() + } - // Get the network input streams - val networkInputStreams = graph.getInputStreams().filter(s => s match { - case n: NetworkInputDStream[_] => true - case _ => false - }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray - - // Start the network input tracker (must start before receivers) - if (networkInputStreams.length > 0) { - networkInputTracker = new NetworkInputTracker(this, networkInputStreams) - networkInputTracker.start() - } - Thread.sleep(1000) + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + */ + def awaitTermination() { + waiter.waitForStopOrError() + } - // Start the scheduler - scheduler.start() + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * @param timeout time to wait in milliseconds + */ + def awaitTermination(timeout: Long) { + waiter.waitForStopOrError(timeout) } /** * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not */ - def stop() { - try { - if (scheduler != null) scheduler.stop() - if (networkInputTracker != null) networkInputTracker.stop() - sc.stop() - logInfo("StreamingContext stopped successfully") - } catch { - case e: Exception => logWarning("Error while stopping", e) - } + def stop(stopSparkContext: Boolean = true) = synchronized { + scheduler.stop() + logInfo("StreamingContext stopped successfully") + waiter.notifyStop() + if (stopSparkContext) sc.stop() } } @@ -472,6 +466,8 @@ class StreamingContext private ( object StreamingContext extends Logging { + private[streaming] val DEFAULT_CLEANER_TTL = 3600 + implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } @@ -515,37 +511,29 @@ object StreamingContext extends Logging { */ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls) - - protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { + private[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second batch intervals. if (MetadataCleaner.getDelaySeconds(conf) < 0) { - MetadataCleaner.setDelaySeconds(conf, 3600) + MetadataCleaner.setDelaySeconds(conf, DEFAULT_CLEANER_TTL) } val sc = new SparkContext(conf) sc } - protected[streaming] def createNewSparkContext( + private[streaming] def createNewSparkContext( master: String, appName: String, sparkHome: String, jars: Seq[String], environment: Map[String, String] ): SparkContext = { - val conf = SparkContext.updatedConf( new SparkConf(), master, appName, sparkHome, jars, environment) - // Set the default cleaner delay to an hour if not already set. - // This should be sufficient for even 1 second batch intervals. - if (MetadataCleaner.getDelaySeconds(conf) < 0) { - MetadataCleaner.setDelaySeconds(conf, 3600) - } - val sc = new SparkContext(master, appName, sparkHome, jars, environment) - sc + createNewSparkContext(conf) } - protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { time.milliseconds.toString } else if (suffix == null || suffix.length ==0) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 64f38ce1c0..cea4795eb5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -243,17 +243,39 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 0.9.0, replaced by foreachRDD */ + @Deprecated def foreach(foreachFunc: JFunction[R, Void]) { - dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd))) + foreachRDD(foreachFunc) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 0.9.0, replaced by foreachRDD */ + @Deprecated def foreach(foreachFunc: JFunction2[R, Time, Void]) { - dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + foreachRDD(foreachFunc) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JFunction[R, Void]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 523173d45a..b4c46f5e50 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -483,9 +483,28 @@ class JavaStreamingContext(val ssc: StreamingContext) { def start() = ssc.start() /** - * Stop the execution of the streams. + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + */ + def awaitTermination() = ssc.awaitTermination() + + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * @param timeout time to wait in milliseconds + */ + def awaitTermination(timeout: Long) = ssc.awaitTermination(timeout) + + /** + * Stop the execution of the streams. Will stop the associated JavaSparkContext as well. */ def stop() = ssc.stop() + + /** + * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not + */ + def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 1f0f31c4b1..f10d483634 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -239,7 +239,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas reset() return false } - return true + true } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index f01e67fe13..8f84232cab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -43,7 +43,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) * This ensures that InputDStream.compute() is called strictly on increasing * times. */ - override protected def isTimeValid(time: Time): Boolean = { + override private[streaming] def isTimeValid(time: Time): Boolean = { if (!super.isTimeValid(time)) { false // Time not valid } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index d41f726f83..0f1f6fc2ce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -68,7 +68,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte // then this returns an empty RDD. This may happen when recovering from a // master failure if (validTime >= graph.startTime) { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime) Some(new BlockRDD[T](ssc.sc, blockIds)) } else { Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index e0ff3ccba4..b34ba7b9b4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -65,7 +65,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) //logDebug("Generating state RDD for time " + validTime) - return Some(stateRDD) + Some(stateRDD) } case None => { // If parent RDD does not exist @@ -76,7 +76,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( updateFuncLocal(i) } val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) - return Some(stateRDD) + Some(stateRDD) } } } @@ -98,11 +98,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val groupedRDD = parentRDD.groupByKey(partitioner) val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) //logDebug("Generating state RDD for time " + validTime + " (first)") - return Some(sessionRDD) + Some(sessionRDD) } case None => { // If parent RDD does not exist, then nothing to do! //logDebug("Not generating state RDD (no previous state, no parent)") - return None + None } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index c8ee93bf5b..7e0f6b2cdf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.Time +import scala.util.Try /** * Class representing a Spark computation. It may contain multiple Spark jobs. @@ -25,12 +26,10 @@ import org.apache.spark.streaming.Time private[streaming] class Job(val time: Time, func: () => _) { var id: String = _ + var result: Try[_] = null - def run(): Long = { - val startTime = System.currentTimeMillis - func() - val stopTime = System.currentTimeMillis - (stopTime - startTime) + def run() { + result = Try(func()) } def setId(number: Int) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2fa6853ae0..b5f11d3440 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.scheduler -import akka.actor.{Props, Actor} -import org.apache.spark.SparkEnv -import org.apache.spark.Logging +import akka.actor.{ActorRef, ActorSystem, Props, Actor} +import org.apache.spark.{SparkException, SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} +import scala.util.{Failure, Success, Try} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -37,29 +37,38 @@ private[scheduler] case class ClearCheckpointData(time: Time) extends JobGenerat private[streaming] class JobGenerator(jobScheduler: JobScheduler) extends Logging { - val ssc = jobScheduler.ssc - val graph = ssc.graph - val eventProcessorActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { - case event: JobGeneratorEvent => - logDebug("Got event of type " + event.getClass.getName) - processEvent(event) - } - })) + private val ssc = jobScheduler.ssc + private val graph = ssc.graph val clock = { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") Class.forName(clockClass).newInstance().asInstanceOf[Clock] } - val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventProcessorActor ! GenerateJobs(new Time(longTime))) - lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { + private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, + longTime => eventActor ! GenerateJobs(new Time(longTime))) + private lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } + // eventActor is created when generator starts. + // This not being null means the scheduler has been started and not stopped + private var eventActor: ActorRef = null + + /** Start generation of jobs */ def start() = synchronized { + if (eventActor != null) { + throw new SparkException("JobGenerator already started") + } + + eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { + def receive = { + case event: JobGeneratorEvent => + logDebug("Got event of type " + event.getClass.getName) + processEvent(event) + } + }), "JobGenerator") if (ssc.isCheckpointPresent) { restart() } else { @@ -67,22 +76,26 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } - def stop() { - timer.stop() - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("JobGenerator stopped") + /** Stop generation of jobs */ + def stop() = synchronized { + if (eventActor != null) { + timer.stop() + ssc.env.actorSystem.stop(eventActor) + if (checkpointWriter != null) checkpointWriter.stop() + ssc.graph.stop() + logInfo("JobGenerator stopped") + } } /** * On batch completion, clear old metadata and checkpoint computation. */ - private[scheduler] def onBatchCompletion(time: Time) { - eventProcessorActor ! ClearMetadata(time) + def onBatchCompletion(time: Time) { + eventActor ! ClearMetadata(time) } - private[streaming] def onCheckpointCompletion(time: Time) { - eventProcessorActor ! ClearCheckpointData(time) + def onCheckpointCompletion(time: Time) { + eventActor ! ClearCheckpointData(time) } /** Processes all events */ @@ -121,14 +134,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val checkpointTime = ssc.initialCheckpoint.checkpointTime val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) val downTimes = checkpointTime.until(restartTime, batchDuration) - logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", ")) + logInfo("Batches during down time (" + downTimes.size + " batches): " + + downTimes.mkString(", ")) // Batches that were unprocessed before failure val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) - logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) + logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + + pendingTimes.mkString(", ")) // Reschedule jobs for these times val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) + logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => jobScheduler.runJobs(time, graph.generateJobs(time)) ) @@ -141,15 +157,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { SparkEnv.set(ssc.env) - logInfo("\n-----------------------------------------------------\n") - jobScheduler.runJobs(time, graph.generateJobs(time)) - eventProcessorActor ! DoCheckpoint(time) + Try(graph.generateJobs(time)) match { + case Success(jobs) => jobScheduler.runJobs(time, jobs) + case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + } + eventActor ! DoCheckpoint(time) } /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - eventProcessorActor ! DoCheckpoint(time) + eventActor ! DoCheckpoint(time) } /** Clear DStream checkpoint data for the given `time`. */ @@ -166,4 +184,3 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 30c070c274..de675d3c7f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,36 +17,68 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import scala.util.{Failure, Success, Try} +import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} -import scala.collection.mutable.HashSet +import akka.actor.{ActorRef, Actor, Props} +import org.apache.spark.{SparkException, Logging, SparkEnv} import org.apache.spark.streaming._ + +private[scheduler] sealed trait JobSchedulerEvent +private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent +private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent +private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent + /** * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate - * the jobs and runs them using a thread pool. Number of threads + * the jobs and runs them using a thread pool. */ private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { - val jobSets = new ConcurrentHashMap[Time, JobSet] - val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - val executor = Executors.newFixedThreadPool(numConcurrentJobs) - val generator = new JobGenerator(this) + private val jobSets = new ConcurrentHashMap[Time, JobSet] + private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) + private val executor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobGenerator = new JobGenerator(this) + val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() - def clock = generator.clock + // These two are created only when scheduler starts. + // eventActor not being null means the scheduler has been started and not stopped + var networkInputTracker: NetworkInputTracker = null + private var eventActor: ActorRef = null + + + def start() = synchronized { + if (eventActor != null) { + throw new SparkException("JobScheduler already started") + } - def start() { - generator.start() + eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { + def receive = { + case event: JobSchedulerEvent => processEvent(event) + } + }), "JobScheduler") + listenerBus.start() + networkInputTracker = new NetworkInputTracker(ssc) + networkInputTracker.start() + Thread.sleep(1000) + jobGenerator.start() + logInfo("JobScheduler started") } - def stop() { - generator.stop() - executor.shutdown() - if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { - executor.shutdownNow() + def stop() = synchronized { + if (eventActor != null) { + jobGenerator.stop() + networkInputTracker.stop() + executor.shutdown() + if (!executor.awaitTermination(2, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + listenerBus.stop() + ssc.env.actorSystem.stop(eventActor) + logInfo("JobScheduler stopped") } } @@ -61,46 +93,67 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } } - def getPendingTimes(): Array[Time] = { - jobSets.keySet.toArray(new Array[Time](0)) + def getPendingTimes(): Seq[Time] = { + jobSets.keySet.toSeq + } + + def reportError(msg: String, e: Throwable) { + eventActor ! ErrorReported(msg, e) } - private def beforeJobStart(job: Job) { + private def processEvent(event: JobSchedulerEvent) { + try { + event match { + case JobStarted(job) => handleJobStart(job) + case JobCompleted(job) => handleJobCompletion(job) + case ErrorReported(m, e) => handleError(m, e) + } + } catch { + case e: Throwable => + reportError("Error in job scheduler", e) + } + } + + private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) if (!jobSet.hasStarted) { - listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo())) + listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.beforeJobStart(job) + jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) - SparkEnv.set(generator.ssc.env) + SparkEnv.set(ssc.env) } - private def afterJobEnd(job: Job) { - val jobSet = jobSets.get(job.time) - jobSet.afterJobStop(job) - logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) - if (jobSet.hasCompleted) { - jobSets.remove(jobSet.time) - generator.onBatchCompletion(jobSet.time) - logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( - jobSet.totalDelay / 1000.0, jobSet.time.toString, - jobSet.processingDelay / 1000.0 - )) - listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo())) + private def handleJobCompletion(job: Job) { + job.result match { + case Success(_) => + val jobSet = jobSets.get(job.time) + jobSet.handleJobCompletion(job) + logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) + if (jobSet.hasCompleted) { + jobSets.remove(jobSet.time) + jobGenerator.onBatchCompletion(jobSet.time) + logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( + jobSet.totalDelay / 1000.0, jobSet.time.toString, + jobSet.processingDelay / 1000.0 + )) + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo)) + } + case Failure(e) => + reportError("Error running job " + job, e) } } - private[streaming] - class JobHandler(job: Job) extends Runnable { + private def handleError(msg: String, e: Throwable) { + logError(msg, e) + ssc.waiter.notifyError(e) + } + + private class JobHandler(job: Job) extends Runnable { def run() { - beforeJobStart(job) - try { - job.run() - } catch { - case e: Exception => - logError("Running " + job + " failed", e) - } - afterJobEnd(job) + eventActor ! JobStarted(job) + job.run() + eventActor ! JobCompleted(job) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 57268674ea..fcf303aee6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashSet} import org.apache.spark.streaming.Time /** Class representing a set of Jobs @@ -27,25 +27,25 @@ private[streaming] case class JobSet(time: Time, jobs: Seq[Job]) { private val incompleteJobs = new HashSet[Job]() - var submissionTime = System.currentTimeMillis() // when this jobset was submitted - var processingStartTime = -1L // when the first job of this jobset started processing - var processingEndTime = -1L // when the last job of this jobset finished processing + private val submissionTime = System.currentTimeMillis() // when this jobset was submitted + private var processingStartTime = -1L // when the first job of this jobset started processing + private var processingEndTime = -1L // when the last job of this jobset finished processing jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } incompleteJobs ++= jobs - def beforeJobStart(job: Job) { + def handleJobStart(job: Job) { if (processingStartTime < 0) processingStartTime = System.currentTimeMillis() } - def afterJobStop(job: Job) { + def handleJobCompletion(job: Job) { incompleteJobs -= job if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted() = (processingStartTime > 0) + def hasStarted = processingStartTime > 0 - def hasCompleted() = incompleteJobs.isEmpty + def hasCompleted = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. not including the time they wait in the streaming scheduler queue) @@ -57,7 +57,7 @@ case class JobSet(time: Time, jobs: Seq[Job]) { processingEndTime - time.milliseconds } - def toBatchInfo(): BatchInfo = { + def toBatchInfo: BatchInfo = { new BatchInfo( time, submissionTime, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index 75f7244643..0d9733fa69 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -19,8 +19,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkException, Logging, SparkEnv} import org.apache.spark.SparkContext._ import scala.collection.mutable.HashMap @@ -32,6 +31,7 @@ import akka.pattern.ask import akka.dispatch._ import org.apache.spark.storage.BlockId import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.util.AkkaUtils private[streaming] sealed trait NetworkInputTrackerMessage private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage @@ -39,33 +39,47 @@ private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], m private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage /** - * This class manages the execution of the receivers of NetworkInputDStreams. + * This class manages the execution of the receivers of NetworkInputDStreams. Instance of + * this class must be created after all input streams have been added and StreamingContext.start() + * has been called because it needs the final set of input streams at the time of instantiation. */ private[streaming] -class NetworkInputTracker( - @transient ssc: StreamingContext, - @transient networkInputStreams: Array[NetworkInputDStream[_]]) - extends Logging { +class NetworkInputTracker(ssc: StreamingContext) extends Logging { + val networkInputStreams = ssc.graph.getNetworkInputStreams() val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] val receivedBlockIds = new HashMap[Int, Queue[BlockId]] - val timeout = 5000.milliseconds + val timeout = AkkaUtils.askTimeout(ssc.conf) + + // actor is created when generator starts. + // This not being null means the tracker has been started and not stopped + var actor: ActorRef = null var currentTime: Time = null /** Start the actor and receiver execution thread. */ def start() { - ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") - receiverExecutor.start() + if (actor != null) { + throw new SparkException("NetworkInputTracker already started") + } + + if (!networkInputStreams.isEmpty) { + actor = ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") + receiverExecutor.start() + logInfo("NetworkInputTracker started") + } } /** Stop the receiver execution thread. */ def stop() { - // TODO: stop the actor as well - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() + if (!networkInputStreams.isEmpty && actor != null) { + receiverExecutor.interrupt() + receiverExecutor.stopReceivers() + ssc.env.actorSystem.stop(actor) + logInfo("NetworkInputTracker stopped") + } } /** Return all the blocks received from a receiver. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 36225e190c..461ea35064 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -24,9 +24,10 @@ import org.apache.spark.util.Distribution sealed trait StreamingListenerEvent case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent - case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent +/** An event used in the listener to shutdown the listener daemon thread. */ +private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent /** * A listener interface for receiving information about an ongoing streaming diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 110a20f282..3063cf10a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -31,7 +31,7 @@ private[spark] class StreamingListenerBus() extends Logging { private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY) private var queueFullErrorMessageLogged = false - new Thread("StreamingListenerBus") { + val listenerThread = new Thread("StreamingListenerBus") { setDaemon(true) override def run() { while (true) { @@ -41,11 +41,18 @@ private[spark] class StreamingListenerBus() extends Logging { listeners.foreach(_.onBatchStarted(batchStarted)) case batchCompleted: StreamingListenerBatchCompleted => listeners.foreach(_.onBatchCompleted(batchCompleted)) + case StreamingListenerShutdown => + // Get out of the while loop and shutdown the daemon thread + return case _ => } } } - }.start() + } + + def start() { + listenerThread.start() + } def addListener(listener: StreamingListener) { listeners += listener @@ -54,9 +61,9 @@ private[spark] class StreamingListenerBus() extends Logging { def post(event: StreamingListenerEvent) { val eventAdded = eventQueue.offer(event) if (!eventAdded && !queueFullErrorMessageLogged) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with the " + - "rate at which tasks are being started by the scheduler.") + logError("Dropping StreamingListenerEvent because no remaining room in event queue. " + + "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + + "rate at which events are being started by the scheduler.") queueFullErrorMessageLogged = true } } @@ -68,7 +75,7 @@ private[spark] class StreamingListenerBus() extends Logging { */ def waitUntilEmpty(timeoutMillis: Int): Boolean = { val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty()) { + while (!eventQueue.isEmpty) { if (System.currentTimeMillis > finishTime) { return false } @@ -76,6 +83,8 @@ private[spark] class StreamingListenerBus() extends Logging { * add overhead in the general case. */ Thread.sleep(10) } - return true + true } + + def stop(): Unit = post(StreamingListenerShutdown) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index f67bb2f6ac..c3a849d276 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -66,7 +66,7 @@ class SystemClock() extends Clock { } Thread.sleep(sleepTime) } - return -1 + -1 } } @@ -96,6 +96,6 @@ class ManualClock() extends Clock { this.wait(100) } } - return currentTime() + currentTime() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 4e6ce6eabd..5b6c048a39 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -90,7 +90,7 @@ object RawTextHelper { } } } - return taken.toIterator + taken.toIterator } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index d644240405..559c247385 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -20,17 +20,7 @@ package org.apache.spark.streaming.util private[streaming] class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { - private val minPollTime = 25L - - private val pollTime = { - if (period / 10.0 > minPollTime) { - (period / 10.0).toLong - } else { - minPollTime - } - } - - private val thread = new Thread() { + private val thread = new Thread("RecurringTimer") { override def run() { loop } } @@ -66,7 +56,6 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => callback(nextTime) nextTime += period } - } catch { case e: InterruptedException => } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index ee6b433d1f..9406e0e20a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -375,15 +375,11 @@ class BasicOperationsSuite extends TestSuiteBase { } test("slice") { - val conf2 = new SparkConf() - .setMaster("local[2]") - .setAppName("BasicOperationsSuite") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - val ssc = new StreamingContext(new SparkContext(conf2), Seconds(1)) + val ssc = new StreamingContext(conf, Seconds(1)) val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) val stream = new TestInputStream[Int](ssc, input, 2) ssc.registerInputStream(stream) - stream.foreach(_ => {}) // Dummy output stream + stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) def getInputFromSlice(fromMillis: Long, toMillis: Long) = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 65e7e5d469..67ce5bc566 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -28,6 +28,8 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.FileInputDStream import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf /** * This test suites tests the checkpointing functionality of DStreams - @@ -142,6 +144,26 @@ class CheckpointSuite extends TestSuiteBase { ssc = null } + // This tests whether spark conf persists through checkpoints, and certain + // configs gets scrubbed + test("persistence of conf through checkpoints") { + val key = "spark.mykey" + val value = "myvalue" + System.setProperty(key, value) + ssc = new StreamingContext(master, framework, batchDuration) + val cp = new Checkpoint(ssc, Time(1000)) + assert(!cp.sparkConf.contains("spark.driver.host")) + assert(!cp.sparkConf.contains("spark.driver.port")) + assert(!cp.sparkConf.contains("spark.hostPort")) + assert(cp.sparkConf.get(key) === value) + ssc.stop() + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + assert(!newCp.sparkConf.contains("spark.driver.host")) + assert(!newCp.sparkConf.contains("spark.driver.port")) + assert(!newCp.sparkConf.contains("spark.hostPort")) + assert(newCp.sparkConf.get(key) === value) + } + // This tests whether the systm can recover from a master failure with simple // non-stateful operations. This assumes as reliable, replayable input diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala new file mode 100644 index 0000000000..a477d200c9 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.scalatest.{FunSuite, BeforeAndAfter} +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ +import org.apache.spark.{SparkException, SparkConf, SparkContext} +import org.apache.spark.util.{Utils, MetadataCleaner} + +class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { + + val master = "local[2]" + val appName = this.getClass.getSimpleName + val batchDuration = Seconds(1) + val sparkHome = "someDir" + val envPair = "key" -> "value" + val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100 + + var sc: SparkContext = null + var ssc: StreamingContext = null + + before { + System.clearProperty("spark.cleaner.ttl") + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + if (sc != null) { + sc.stop() + sc = null + } + } + + test("from no conf constructor") { + ssc = new StreamingContext(master, appName, batchDuration) + assert(ssc.sparkContext.conf.get("spark.master") === master) + assert(ssc.sparkContext.conf.get("spark.app.name") === appName) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from no conf + spark home") { + ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil) + assert(ssc.conf.get("spark.home") === sparkHome) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from no conf + spark home + env") { + ssc = new StreamingContext(master, appName, batchDuration, + sparkHome, Nil, Map(envPair)) + assert(ssc.conf.getExecutorEnv.exists(_ == envPair)) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from conf without ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + ssc = new StreamingContext(myConf, batchDuration) + assert(MetadataCleaner.getDelaySeconds(ssc.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from conf with ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl) + } + + test("from existing SparkContext without ttl set") { + sc = new SparkContext(master, appName) + val exception = intercept[SparkException] { + ssc = new StreamingContext(sc, batchDuration) + } + assert(exception.getMessage.contains("ttl")) + } + + test("from existing SparkContext with ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl) + } + + test("from checkpoint") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + val ssc1 = new StreamingContext(myConf, batchDuration) + val cp = new Checkpoint(ssc1, Time(1000)) + assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl) + ssc1.stop() + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl) + ssc = new StreamingContext(null, cp, null) + assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl) + } + + test("start multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + + ssc.start() + intercept[SparkException] { + ssc.start() + } + } + + test("stop multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + ssc.start() + ssc.stop() + ssc.stop() + ssc = null + } + + test("stop only streaming context") { + ssc = new StreamingContext(master, appName, batchDuration) + sc = ssc.sparkContext + addInputStream(ssc).register + ssc.start() + ssc.stop(false) + ssc = null + assert(sc.makeRDD(1 to 100).collect().size === 100) + ssc = new StreamingContext(sc, batchDuration) + } + + test("awaitTermination") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register + + // test whether start() blocks indefinitely or not + failAfter(2000 millis) { + ssc.start() + } + + // test whether waitForStop() exits after give amount of time + failAfter(1000 millis) { + ssc.awaitTermination(500) + } + + // test whether waitForStop() does not exit if not time is given + val exception = intercept[Exception] { + failAfter(1000 millis) { + ssc.awaitTermination() + throw new Exception("Did not wait for stop") + } + } + assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + + // test whether wait exits if context is stopped + failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown + new Thread() { + override def run { + Thread.sleep(500) + ssc.stop() + } + }.start() + ssc.awaitTermination() + } + } + + test("awaitTermination with error in task") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => { throw new TestException("error in map task"); x}) + .foreach(_.count) + + val exception = intercept[Exception] { + ssc.start() + ssc.awaitTermination(5000) + } + assert(exception.getMessage.contains("map task"), "Expected exception not thrown") + } + + test("awaitTermination with error in job generation") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + + inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register + val exception = intercept[TestException] { + ssc.start() + ssc.awaitTermination(5000) + } + assert(exception.getMessage.contains("transform"), "Expected exception not thrown") + } + + def addInputStream(s: StreamingContext): DStream[Int] = { + val input = (1 to 100).map(i => (1 to i)) + val inputStream = new TestInputStream(s, input, 1) + s.registerInputStream(inputStream) + inputStream + } +} + +class TestException(msg: String) extends Exception(msg)
\ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 1979a0cedb..9b2bb57e77 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -137,7 +137,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { val conf = new SparkConf() .setMaster(master) .setAppName(framework) - .set("spark.cleaner.ttl", "3600") + .set("spark.cleaner.ttl", StreamingContext.DEFAULT_CLEANER_TTL.toString) // Default before function for any streaming test suite. Override this // if you want to add your stuff to "before" (i.e., don't call before { } ) @@ -272,10 +272,11 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { val startTime = System.currentTimeMillis() while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - Thread.sleep(10) + ssc.awaitTermination(50) } val timeTaken = System.currentTimeMillis() - startTime - + logInfo("Output generated in " + timeTaken + " milliseconds") + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") |