From e95ff45b53bf995d89f1825b9581cc18a083a438 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 13 Oct 2012 20:10:49 -0700 Subject: Implemented checkpointing of StreamingContext and DStream graph. --- .../main/scala/spark/streaming/Checkpoint.scala | 92 +++++++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 123 ++++++++++++++------- .../main/scala/spark/streaming/DStreamGraph.scala | 80 ++++++++++++++ .../scala/spark/streaming/FileInputDStream.scala | 59 ++++++---- .../spark/streaming/ReducedWindowedDStream.scala | 80 +++++++------- .../src/main/scala/spark/streaming/Scheduler.scala | 33 +++--- .../main/scala/spark/streaming/StateDStream.scala | 20 ++-- .../scala/spark/streaming/StreamingContext.scala | 109 ++++++++++++------ .../examples/FileStreamWithCheckpoint.scala | 76 +++++++++++++ .../scala/spark/streaming/examples/Grep2.scala | 2 +- .../spark/streaming/examples/WordCount2.scala | 2 +- .../scala/spark/streaming/examples/WordMax2.scala | 2 +- .../spark/streaming/util/RecurringTimer.scala | 19 +++- 13 files changed, 534 insertions(+), 163 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/Checkpoint.scala create mode 100644 streaming/src/main/scala/spark/streaming/DStreamGraph.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala (limited to 'streaming') diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala new file mode 100644 index 0000000000..3bd8fd5a27 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark.Utils + +import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.conf.Configuration + +import java.io.{ObjectInputStream, ObjectOutputStream} + +class Checkpoint(@transient ssc: StreamingContext) extends Serializable { + val master = ssc.sc.master + val frameworkName = ssc.sc.frameworkName + val sparkHome = ssc.sc.sparkHome + val jars = ssc.sc.jars + val graph = ssc.graph + val batchDuration = ssc.batchDuration + val checkpointFile = ssc.checkpointFile + val checkpointInterval = ssc.checkpointInterval + + def saveToFile(file: String) { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(this) + oos.close() + fs.close() + } + + def toBytes(): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } +} + +object Checkpoint { + + def loadFromFile(file: String): Checkpoint = { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (!fs.exists(path)) { + throw new Exception("Could not read checkpoint file " + path) + } + val fis = fs.open(path) + val ois = new ObjectInputStream(fis) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp + } + + def fromBytes(bytes: Array[Byte]): Checkpoint = { + Utils.deserialize[Checkpoint](bytes) + } + + /*def toBytes(ssc: StreamingContext): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } + + + def saveContext(ssc: StreamingContext, file: String) { + val cp = new Checkpoint(ssc) + val path = new Path(file) 
+ val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(cp) + oos.close() + fs.close() + } + + def loadContext(file: String): StreamingContext = { + loadCheckpoint(file).createNewContext() + } + */ +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 7e8098c346..78e4c57647 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -2,20 +2,19 @@ package spark.streaming import spark.streaming.StreamingContext._ -import spark.RDD -import spark.UnionRDD -import spark.Logging +import spark._ import spark.SparkContext._ import spark.storage.StorageLevel -import spark.Partitioner import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import scala.Some -abstract class DStream[T: ClassManifest] (@transient val ssc: StreamingContext) -extends Logging with Serializable { +abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) +extends Serializable with Logging { initLogging() @@ -41,10 +40,10 @@ extends Logging with Serializable { */ // Variable to store the RDDs generated earlier in time - @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] () + protected val generatedRDDs = new HashMap[Time, RDD[T]] () // Variable to be set to the first time seen by the DStream (effective time zero) - protected[streaming] var zeroTime: Time = null + protected var zeroTime: Time = null // Variable to specify storage level protected var storageLevel: StorageLevel = StorageLevel.NONE @@ -53,6 +52,9 @@ extends Logging with Serializable { protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint protected var checkpointInterval: Time = null + // Reference to whole DStream graph, so that checkpointing process can lock it + protected var graph: DStreamGraph = null + // Change this RDD's storage level def persist( storageLevel: StorageLevel, @@ -77,7 +79,7 @@ extends Logging with Serializable { // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() - def isInitialized = (zeroTime != null) + def isInitialized() = (zeroTime != null) /** * This method initializes the DStream by setting the "zero" time, based on which @@ -85,15 +87,33 @@ extends Logging with Serializable { * its parent DStreams. 
*/ protected[streaming] def initialize(time: Time) { - if (zeroTime == null) { - zeroTime = time + if (zeroTime != null) { + throw new Exception("ZeroTime is already initialized, cannot initialize it again") } + zeroTime = time logInfo(this + " initialized") dependencies.foreach(_.initialize(zeroTime)) } + protected[streaming] def setContext(s: StreamingContext) { + if (ssc != null && ssc != s) { + throw new Exception("Context is already set, cannot set it again") + } + ssc = s + logInfo("Set context for " + this.getClass.getSimpleName) + dependencies.foreach(_.setContext(ssc)) + } + + protected[streaming] def setGraph(g: DStreamGraph) { + if (graph != null && graph != g) { + throw new Exception("Graph is already set, cannot set it again") + } + graph = g + dependencies.foreach(_.setGraph(graph)) + } + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - protected def isTimeValid (time: Time): Boolean = { + protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { @@ -158,13 +178,42 @@ extends Logging with Serializable { } } + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + println(this.getClass().getSimpleName + ".writeObject used") + if (graph != null) { + graph.synchronized { + if (graph.checkpointInProgress) { + oos.defaultWriteObject() + } else { + val msg = "Object of " + this.getClass.getName + " is being serialized " + + " possibly as a part of closure of an RDD operation. This is because " + + " the DStream object is being referred to from within the closure. " + + " Please rewrite the RDD operation inside this DStream to avoid this. " + + " This has been enforced to avoid bloating of Spark tasks " + + " with unnecessary objects." 
+ throw new java.io.NotSerializableException(msg) + } + } + } else { + throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.") + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + } + /** * -------------- * DStream operations * -------------- */ - - def map[U: ClassManifest](mapFunc: T => U) = new MappedDStream(this, ssc.sc.clean(mapFunc)) + def map[U: ClassManifest](mapFunc: T => U) = { + new MappedDStream(this, ssc.sc.clean(mapFunc)) + } def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) @@ -262,19 +311,15 @@ extends Logging with Serializable { // Get all the RDDs between fromTime to toTime (both included) def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) - - while (time >= zeroTime && time >= fromTime) { getOrCompute(time) match { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not get old reduced RDD for time " + time) + case None => //throw new Exception("Could not get RDD for time " + time) } time -= slideTime } - rdds.toSeq } @@ -284,12 +329,16 @@ extends Logging with Serializable { } -abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) - extends DStream[T](ssc) { +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { override def dependencies = List() - override def slideTime = ssc.batchDuration + override def slideTime = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.batchDuration == null) throw new Exception("ssc.batchDuration is null") + ssc.batchDuration + } def start() @@ -302,7 +351,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) */ class MappedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], mapFunc: T => U ) extends DStream[U](parent.ssc) { @@ -321,7 +370,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( */ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { @@ -340,7 +389,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( */ class FilteredDStream[T: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { @@ -359,7 +408,7 @@ class FilteredDStream[T: ClassManifest]( */ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], mapPartFunc: Iterator[T] => Iterator[U] ) extends DStream[U](parent.ssc) { @@ -377,7 +426,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ -class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) +class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -395,7 +444,7 @@ class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) */ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - @transient parent: DStream[(K,V)], + parent: DStream[(K,V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, @@ -420,7 +469,7 @@ class 
ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ -class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) +class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) extends DStream[T](parents(0).ssc) { if (parents.length == 0) { @@ -459,7 +508,7 @@ class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) */ class PerElementForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: T => Unit ) extends DStream[Unit](parent.ssc) { @@ -490,7 +539,7 @@ class PerElementForEachDStream[T: ClassManifest] ( */ class PerRDDForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -518,15 +567,15 @@ class PerRDDForEachDStream[T: ClassManifest] ( */ class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Time = parent.slideTime - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) } +} diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala new file mode 100644 index 0000000000..67859e0131 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -0,0 +1,80 @@ +package spark.streaming + +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import collection.mutable.ArrayBuffer + +final class DStreamGraph extends Serializable { + + private val inputStreams = new ArrayBuffer[InputDStream[_]]() + private val outputStreams = new ArrayBuffer[DStream[_]]() + + private[streaming] var zeroTime: Time = null + private[streaming] var checkpointInProgress = false; + + def started() = (zeroTime != null) + + def start(time: Time) { + this.synchronized { + if (started) { + throw new Exception("DStream graph computation already started") + } + zeroTime = time + outputStreams.foreach(_.initialize(zeroTime)) + inputStreams.par.foreach(_.start()) + } + + } + + def stop() { + this.synchronized { + inputStreams.par.foreach(_.stop()) + } + } + + private[streaming] def setContext(ssc: StreamingContext) { + this.synchronized { + outputStreams.foreach(_.setContext(ssc)) + } + } + + def addInputStream(inputStream: InputDStream[_]) { + inputStream.setGraph(this) + inputStreams += inputStream + } + + def addOutputStream(outputStream: DStream[_]) { + outputStream.setGraph(this) + outputStreams += outputStream + } + + def getInputStreams() = inputStreams.toArray + + def getOutputStreams() = outputStreams.toArray + + def generateRDDs(time: Time): Seq[Job] = { + this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + } + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + this.synchronized { + checkpointInProgress = true + oos.defaultWriteObject() + checkpointInProgress = false + } + println("DStreamGraph.writeObject used") + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + 
this.synchronized { + checkpointInProgress = true + ois.defaultReadObject() + checkpointInProgress = false + } + println("DStreamGraph.readObject used") + } +} + diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 96a64f0018..29ae89616e 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -1,33 +1,45 @@ package spark.streaming -import spark.SparkContext import spark.RDD -import spark.BlockRDD import spark.UnionRDD -import spark.storage.StorageLevel -import spark.streaming._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import java.io.{ObjectInputStream, IOException} class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - ssc: StreamingContext, - directory: Path, + @transient ssc_ : StreamingContext, + directory: String, filter: PathFilter = FileInputDStream.defaultPathFilter, newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc) { - - val fs = directory.getFileSystem(new Configuration()) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + /* + @transient @noinline lazy val path = { + //if (directory == null) throw new Exception("directory is null") + //println(directory) + new Path(directory) + } + @transient lazy val fs = path.getFileSystem(new Configuration()) + */ + var lastModTime: Long = 0 - + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + override def start() { if (newFilesOnly) { lastModTime = System.currentTimeMillis() @@ -58,7 +70,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } - val newFiles = fs.listStatus(directory, newFilter) + val newFiles = fs.listStatus(path, newFilter) logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) if (newFiles.length > 0) { lastModTime = newFilter.latestModTime @@ -67,10 +79,19 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) Some(newRDD) } + /* + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + println("HERE HERE" + this.directory) + } + */ + } object FileInputDStream { - val defaultPathFilter = new PathFilter { + val defaultPathFilter = new PathFilter with Serializable { def accept(path: Path): Boolean = { val file = path.getName() if (file.startsWith(".") || file.endsWith("_tmp")) { diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index b0beaba94d..e161b5ba92 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -10,9 +10,10 @@ import spark.SparkContext._ import 
spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer +import collection.SeqProxy class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - @transient parent: DStream[(K, V)], + parent: DStream[(K, V)], reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, _windowTime: Time, @@ -46,6 +47,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( } override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc val currentTime = validTime val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) @@ -84,54 +87,47 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // Cogroup the reduced RDDs and merge the reduced values val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValuesFunc) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - Some(mergedValuesRDD) - } - - def mergeValues(numOldValues: Int, numNewValues: Int)(seqOfValues: Seq[Seq[V]]): V = { - - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - - // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - - if (seqOfValues(0).isEmpty) { + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found") + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } - // Reduce the new values - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - return newValues.reduce(reduceFunc) - } else { + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - // Get the 
previous window's reduced value - var tempValue = seqOfValues(0).head + Some(mergedValuesRDD) + } - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - // println("old values = " + oldValues.map(_.toString).reduce(_ + " " + _)) - tempValue = invReduceFunc(tempValue, oldValues.reduce(reduceFunc)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - tempValue = reduceFunc(tempValue, newValues.reduce(reduceFunc)) - } - // println("final value = " + tempValue) - return tempValue - } - } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index d2e907378d..d62b7e7140 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -11,45 +11,44 @@ import scala.collection.mutable.HashMap sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage -class Scheduler( - ssc: StreamingContext, - inputStreams: Array[InputDStream[_]], - outputStreams: Array[DStream[_]]) +class Scheduler(ssc: StreamingContext) extends Logging { initLogging() + val graph = ssc.graph val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) - + + def start() { - val zeroTime = Time(timer.start()) - outputStreams.foreach(_.initialize(zeroTime)) - inputStreams.par.foreach(_.start()) + if (graph.started) { + timer.restart(graph.zeroTime.milliseconds) + } else { + val zeroTime = Time(timer.start()) + graph.start(zeroTime) + } logInfo("Scheduler started") } def stop() { timer.stop() - inputStreams.par.foreach(_.stop()) + graph.stop() logInfo("Scheduler stopped") } - def generateRDDs (time: Time) { + def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") logInfo("Generating RDDs for time " + time) - outputStreams.foreach(outputStream => { - outputStream.generateJob(time) match { - case Some(job) => submitJob(job) - case None => - } - } - ) + graph.generateRDDs(time).foreach(submitJob) logInfo("Generated RDDs for time " + time) + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + ssc.checkpoint() + } } def submitJob(job: Job) { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index c40f70c91d..d223f25dfc 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -7,6 +7,12 @@ import spark.MapPartitionsRDD import spark.SparkContext._ import spark.storage.StorageLevel + +class StateRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U], rememberPartitioner: Boolean) + extends MapPartitionsRDD[U, T](prev, f) { + override val partitioner = if (rememberPartitioner) prev.partitioner else None +} + class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( @transient parent: 
DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], @@ -14,11 +20,6 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { - class SpecialMapPartitionsRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U]) - extends MapPartitionsRDD(prev, f) { - override val partitioner = if (rememberPartitioner) prev.partitioner else None - } - override def dependencies = List(parent) override def slideTime = parent.slideTime @@ -79,19 +80,18 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the cogrouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { val i = iterator.map(t => { (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) }) updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc) + val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } case None => { // If parent RDD does not exist, then return old state RDD - //logDebug("Generating state RDD for time " + validTime + " (no change)") return Some(prevStateRDD) } } @@ -107,12 +107,12 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the grouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => { + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) } val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc) + val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 12f3626680..1499ef4ea2 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -21,31 +21,70 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -class StreamingContext (@transient val sc: SparkContext) extends Logging { +class StreamingContext ( + sc_ : SparkContext, + cp_ : Checkpoint + ) extends Logging { + + def this(sparkContext: SparkContext) = this(sparkContext, null) def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = - this(new SparkContext(master, frameworkName, sparkHome, jars)) + this(new SparkContext(master, frameworkName, sparkHome, jars), null) + + def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + + def this(cp_ : Checkpoint) = this(null, cp_) initLogging() + if (sc_ == null && cp_ == null) { + throw new Exception("Streaming Context cannot be initilalized with " + + "both SparkContext and checkpoint as null") + } + + val 
isCheckpointPresent = (cp_ != null) + + val sc: SparkContext = { + if (isCheckpointPresent) { + new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars) + } else { + sc_ + } + } + val env = SparkEnv.get - - val inputStreams = new ArrayBuffer[InputDStream[_]]() - val outputStreams = new ArrayBuffer[DStream[_]]() + + val graph: DStreamGraph = { + if (isCheckpointPresent) { + + cp_.graph.setContext(this) + cp_.graph + } else { + new DStreamGraph() + } + } + val nextNetworkInputStreamId = new AtomicInteger(0) - var batchDuration: Time = null - var scheduler: Scheduler = null + var batchDuration: Time = if (isCheckpointPresent) cp_.batchDuration else null + var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null var networkInputTracker: NetworkInputTracker = null - var receiverJobThread: Thread = null - - def setBatchDuration(duration: Long) { - setBatchDuration(Time(duration)) - } - + var receiverJobThread: Thread = null + var scheduler: Scheduler = null + def setBatchDuration(duration: Time) { + if (batchDuration != null) { + throw new Exception("Batch duration alread set as " + batchDuration + + ". cannot set it again.") + } batchDuration = duration } + + def setCheckpointDetails(file: String, interval: Time) { + checkpointFile = file + checkpointInterval = interval + } private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() @@ -59,7 +98,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { converter: (InputStream) => Iterator[T] ): DStream[T] = { val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -69,7 +108,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -94,8 +133,8 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest ](directory: String): DStream[(K, V)] = { - val inputStream = new FileInputDStream[K, V, F](this, new Path(directory)) - inputStreams += inputStream + val inputStream = new FileInputDStream[K, V, F](this, directory) + graph.addInputStream(inputStream) inputStream } @@ -113,24 +152,31 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { defaultRDD: RDD[T] = null ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } - def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = { + def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] val inputStream = createQueueStream(queue, true, null) queue ++= iterator inputStream - } + } + + /** + * This function registers a InputDStream as an input stream that will be + * started (InputDStream.start() called) to get the input data streams. + */ + def registerInputStream(inputStream: InputDStream[_]) { + graph.addInputStream(inputStream) + } - /** * This function registers a DStream as an output stream that will be * computed every interval. 
*/ - def registerOutputStream (outputStream: DStream[_]) { - outputStreams += outputStream + def registerOutputStream(outputStream: DStream[_]) { + graph.addOutputStream(outputStream) } /** @@ -143,13 +189,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { if (batchDuration < Milliseconds(100)) { logWarning("Batch duration of " + batchDuration + " is very low") } - if (inputStreams.size == 0) { - throw new Exception("No input streams created, so nothing to take input from") - } - if (outputStreams.size == 0) { + if (graph.getOutputStreams().size == 0) { throw new Exception("No output streams registered, so nothing to execute") } - } /** @@ -157,7 +199,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { */ def start() { verify() - val networkInputStreams = inputStreams.filter(s => s match { + val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray @@ -169,8 +211,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { } Thread.sleep(1000) - // Start the scheduler - scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray) + + // Start the scheduler + scheduler = new Scheduler(this) scheduler.start() } @@ -189,6 +232,10 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { logInfo("StreamingContext stopped") } + + def checkpoint() { + new Checkpoint(this).saveToFile(checkpointFile) + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala new file mode 100644 index 0000000000..c725035a8a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -0,0 +1,76 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +object FileStreamWithCheckpoint { + + def main(args: Array[String]) { + + if (args.size != 3) { + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") + System.exit(-1) + } + + val directory = new Path(args(1)) + val checkpointFile = args(2) + + val ssc: StreamingContext = { + + if (args(0) == "restart") { + + // Recreated streaming context from specified checkpoint file + new StreamingContext(checkpointFile) + + } else { + + // Create directory if it does not exist + val fs = directory.getFileSystem(new Configuration()) + if (!fs.exists(directory)) fs.mkdirs(directory) + + // Create new streaming context + val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") + ssc_.setBatchDuration(Seconds(1)) + ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + + // Setup the streaming computation + val inputStream = ssc_.createTextFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + + ssc_ + } + } + + // Start the stream computation + startFileWritingThread(directory.toString) + ssc.start() + } + + def startFileWritingThread(directory: String) { + + val fs = new Path(directory).getFileSystem(new Configuration()) + + val fileWritingThread = new Thread() { + override def run() { + val r = new scala.util.Random() + val text = "This is a sample text file with a random number " + while(true) { + val 
number = r.nextInt() + val file = new Path(directory, number.toString) + val fos = fs.create(file) + fos.writeChars(text + number) + fos.close() + println("Created text file " + file) + Thread.sleep(1000) + } + } + } + fileWritingThread.start() + } + +} diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala index 7237142c7c..b1faa65c17 100644 --- a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala @@ -50,7 +50,7 @@ object Grep2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) sentences.filter(_.contains("Culpepper")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index c22949d7b9..8390f4af94 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -93,7 +93,7 @@ object WordCount2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 3658cb302d..fc7567322b 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -50,7 +50,7 @@ object WordMax2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index 5da9fa6ecc..7f19b26a79 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -17,12 +17,23 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } var nextTime = 0L - - def start(): Long = { - nextTime = (math.floor(clock.currentTime / period) + 1).toLong * period - thread.start() + + def start(startTime: Long): Long = { + nextTime = startTime + thread.start() nextTime } + + def start(): Long = { + val startTime = math.ceil(clock.currentTime / period).toLong * period + start(startTime) + } + + def restart(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime + start(newStartTime) + } def stop() { thread.interrupt() -- cgit v1.2.3
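A note on the serialization guard introduced in DStream.writeObject and DStreamGraph.writeObject above: a DStream may only be Java-serialized while the graph that owns it is itself being checkpointed, which is what the graph-level checkpointInProgress flag records. Any other attempt, typically a DStream accidentally captured in an RDD closure, fails immediately with NotSerializableException instead of silently bloating every task. The following is a minimal, self-contained sketch of that pattern; DemoGraph, DemoStream and SerializationGuardDemo are illustrative names, not types from the patch, and the sketch runs without a Spark build.

import java.io.{ByteArrayOutputStream, IOException, NotSerializableException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer

class DemoGraph extends Serializable {
  // Guarded by this graph's monitor; true only while writeObject below is running.
  var checkpointInProgress = false
  private val streams = new ArrayBuffer[DemoStream]()

  def add(stream: DemoStream) {
    stream.graph = this
    streams += stream
  }

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    this.synchronized {
      checkpointInProgress = true
      try {
        oos.defaultWriteObject()   // serializes `streams`, re-entering the lock below
      } finally {
        checkpointInProgress = false
      }
    }
  }
}

class DemoStream(var graph: DemoGraph) extends Serializable {
  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    if (graph == null) {
      throw new NotSerializableException("DemoStream has no graph set")
    }
    graph.synchronized {   // re-entrant for the thread that is checkpointing the graph
      if (graph.checkpointInProgress) {
        oos.defaultWriteObject()   // part of a whole-graph checkpoint: allowed
      } else {
        throw new NotSerializableException(
          "DemoStream serialized on its own, probably captured by an RDD closure")
      }
    }
  }
}

object SerializationGuardDemo {
  private def serialize(obj: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)
    oos.close()
    bos.toByteArray
  }

  def main(args: Array[String]) {
    val graph = new DemoGraph
    val stream = new DemoStream(null)
    graph.add(stream)

    println("graph checkpoint: " + serialize(graph).length + " bytes")   // succeeds
    try {
      serialize(stream)
    } catch {
      case e: NotSerializableException => println("rejected: " + e.getMessage)
    }
  }
}

Serializing the graph succeeds because the flag is set while the graph's own writeObject holds the lock; serializing the stream by itself is rejected, which mirrors the error message added to DStream.writeObject.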
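The core of the ReducedWindowedDStream change is the inlined mergeValues closure: for each key, the cogrouped sequences hold the previous window's reduced value first, then the per-batch values that slid out of the window, then the per-batch values that slid in, and the new windowed value is computed with one inverse-reduce and one reduce rather than re-reducing the entire window. A simplified, standalone rendering of that logic, with illustrative names that are not part of the patch:

object IncrementalWindowMerge {

  // previous:  windowed value for this key from the last interval, if any
  // oldValues: per-batch reduced values that just left the window
  // newValues: per-batch reduced values that just entered the window
  def merge[V](
      previous: Option[V],
      oldValues: Seq[V],
      newValues: Seq[V],
      reduce: (V, V) => V,
      invReduce: (V, V) => V): V = {
    previous match {
      case None =>
        require(newValues.nonEmpty, "no previous value and no new values for key")
        newValues.reduce(reduce)
      case Some(prev) =>
        var value = prev
        if (oldValues.nonEmpty) value = invReduce(value, oldValues.reduce(reduce))
        if (newValues.nonEmpty) value = reduce(value, newValues.reduce(reduce))
        value
    }
  }

  def main(args: Array[String]) {
    // Previous window summed to 10; batches reducing to 1 and 2 slide out,
    // batches reducing to 5 and 6 slide in: 10 - (1 + 2) + (5 + 6) = 18.
    println(merge[Int](Some(10), Seq(1, 2), Seq(5, 6), _ + _, _ - _))
  }
}

In the patch the same computation runs over the Seq[Seq[V]] produced by CoGroupedRDD, and the closure captures only reduceF, invReduceF and the two batch counts as local values, so the DStream itself is never dragged into the task closure (which the serialization guard above would otherwise reject).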
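One more detail worth flagging in RecurringTimer: clock.currentTime, period and the gap are all Longs, so in expressions like math.ceil(clock.currentTime / period).toLong * period the division truncates before ceil is applied and the ceil has no effect, giving the boundary at or before the current time rather than the one after it (the replaced code compensated by adding 1 after the floor). A hedged sketch of period alignment done purely in Long arithmetic; the helper names below are assumptions, not part of the patch:

object PeriodAlignment {

  // First multiple of `period` that is >= `time`.
  def nextBoundary(time: Long, period: Long): Long =
    ((time + period - 1) / period) * period

  // First tick at or after `now` on the grid anchored at `originalStartTime`,
  // i.e. what restart() needs when resuming from a checkpointed zero time.
  def realignedStart(originalStartTime: Long, now: Long, period: Long): Long =
    originalStartTime + nextBoundary(now - originalStartTime, period)

  def main(args: Array[String]) {
    val period = 1000L
    println(nextBoundary(12345L, period))           // 13000
    println(realignedStart(500L, 12345L, period))   // 12500
  }
}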