author     Tathagata Das <tathagata.das1565@gmail.com>  2012-10-13 20:10:49 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com>  2012-10-13 20:10:49 -0700
commit     e95ff45b53bf995d89f1825b9581cc18a083a438 (patch)
tree       c7d34da30bed3bdcb9d264464104aeb795faaf08 /streaming
parent     6d1fe0268530fe555fa065b8fcfa72d53c931db0 (diff)
Implemented checkpointing of StreamingContext and DStream graph.
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala                        |  92
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala                           | 123
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala                      |  80
-rw-r--r--  streaming/src/main/scala/spark/streaming/FileInputDStream.scala                  |  59
-rw-r--r--  streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala            |  80
-rw-r--r--  streaming/src/main/scala/spark/streaming/Scheduler.scala                         |  33
-rw-r--r--  streaming/src/main/scala/spark/streaming/StateDStream.scala                      |  20
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala                  | 109
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala |  76
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/Grep2.scala                    |   2
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCount2.scala               |   2
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordMax2.scala                 |   2
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala               |  19
13 files changed, 534 insertions(+), 163 deletions(-)
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
new file mode 100644
index 0000000000..3bd8fd5a27
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -0,0 +1,92 @@
+package spark.streaming
+
+import spark.Utils
+
+import org.apache.hadoop.fs.{FileUtil, Path}
+import org.apache.hadoop.conf.Configuration
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+
+class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
+ val master = ssc.sc.master
+ val frameworkName = ssc.sc.frameworkName
+ val sparkHome = ssc.sc.sparkHome
+ val jars = ssc.sc.jars
+ val graph = ssc.graph
+ val batchDuration = ssc.batchDuration
+ val checkpointFile = ssc.checkpointFile
+ val checkpointInterval = ssc.checkpointInterval
+
+ def saveToFile(file: String) {
+ val path = new Path(file)
+ val conf = new Configuration()
+ val fs = path.getFileSystem(conf)
+ if (fs.exists(path)) {
+ val bkPath = new Path(path.getParent, path.getName + ".bk")
+ FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
+ println("Moved existing checkpoint file to " + bkPath)
+ }
+ val fos = fs.create(path)
+ val oos = new ObjectOutputStream(fos)
+ oos.writeObject(this)
+ oos.close()
+ fs.close()
+ }
+
+ def toBytes(): Array[Byte] = {
+ val cp = new Checkpoint(ssc)
+ val bytes = Utils.serialize(cp)
+ bytes
+ }
+}
+
+object Checkpoint {
+
+ def loadFromFile(file: String): Checkpoint = {
+ val path = new Path(file)
+ val conf = new Configuration()
+ val fs = path.getFileSystem(conf)
+ if (!fs.exists(path)) {
+ throw new Exception("Could not read checkpoint file " + path)
+ }
+ val fis = fs.open(path)
+ val ois = new ObjectInputStream(fis)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp
+ }
+
+ def fromBytes(bytes: Array[Byte]): Checkpoint = {
+ Utils.deserialize[Checkpoint](bytes)
+ }
+
+ /*def toBytes(ssc: StreamingContext): Array[Byte] = {
+ val cp = new Checkpoint(ssc)
+ val bytes = Utils.serialize(cp)
+ bytes
+ }
+
+
+ def saveContext(ssc: StreamingContext, file: String) {
+ val cp = new Checkpoint(ssc)
+ val path = new Path(file)
+ val conf = new Configuration()
+ val fs = path.getFileSystem(conf)
+ if (fs.exists(path)) {
+ val bkPath = new Path(path.getParent, path.getName + ".bk")
+ FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
+ println("Moved existing checkpoint file to " + bkPath)
+ }
+ val fos = fs.create(path)
+ val oos = new ObjectOutputStream(fos)
+ oos.writeObject(cp)
+ oos.close()
+ fs.close()
+ }
+
+ def loadContext(file: String): StreamingContext = {
+ loadCheckpoint(file).createNewContext()
+ }
+ */
+}
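
The new Checkpoint class amounts to Java-serializing the context metadata to a Hadoop path and reading it back on restart. The following is a minimal standalone sketch of that round trip, assuming only a Hadoop client on the classpath; the object name and any paths passed in are illustrative, not part of the patch.

import java.io.{ObjectInputStream, ObjectOutputStream}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object CheckpointRoundTrip {
  // Mirrors Checkpoint.saveToFile: write any Serializable value to a
  // Hadoop-accessible path (local file, HDFS, ...), overwriting it.
  def save(value: Serializable, file: String) {
    val path = new Path(file)
    val fs = path.getFileSystem(new Configuration())
    val oos = new ObjectOutputStream(fs.create(path, true))
    oos.writeObject(value)
    oos.close()
  }

  // Mirrors Checkpoint.loadFromFile: read the value back and cast it.
  def load[T](file: String): T = {
    val path = new Path(file)
    val fs = path.getFileSystem(new Configuration())
    val ois = new ObjectInputStream(fs.open(path))
    val value = ois.readObject().asInstanceOf[T]
    ois.close()
    value
  }
}

Loading on restart is then just a load-and-cast of the Checkpoint, which is what the new StreamingContext(file) constructor further below relies on.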
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 7e8098c346..78e4c57647 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -2,20 +2,19 @@ package spark.streaming
import spark.streaming.StreamingContext._
-import spark.RDD
-import spark.UnionRDD
-import spark.Logging
+import spark._
import spark.SparkContext._
import spark.storage.StorageLevel
-import spark.Partitioner
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import java.util.concurrent.ArrayBlockingQueue
+import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
+import scala.Some
-abstract class DStream[T: ClassManifest] (@transient val ssc: StreamingContext)
-extends Logging with Serializable {
+abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
+extends Serializable with Logging {
initLogging()
@@ -41,10 +40,10 @@ extends Logging with Serializable {
*/
// Variable to store the RDDs generated earlier in time
- @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] ()
+ protected val generatedRDDs = new HashMap[Time, RDD[T]] ()
// Variable to be set to the first time seen by the DStream (effective time zero)
- protected[streaming] var zeroTime: Time = null
+ protected var zeroTime: Time = null
// Variable to specify storage level
protected var storageLevel: StorageLevel = StorageLevel.NONE
@@ -53,6 +52,9 @@ extends Logging with Serializable {
protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint
protected var checkpointInterval: Time = null
+ // Reference to whole DStream graph, so that checkpointing process can lock it
+ protected var graph: DStreamGraph = null
+
// Change this RDD's storage level
def persist(
storageLevel: StorageLevel,
@@ -77,7 +79,7 @@ extends Logging with Serializable {
// Turn on the default caching level for this RDD
def cache(): DStream[T] = persist()
- def isInitialized = (zeroTime != null)
+ def isInitialized() = (zeroTime != null)
/**
* This method initializes the DStream by setting the "zero" time, based on which
@@ -85,15 +87,33 @@ extends Logging with Serializable {
* its parent DStreams.
*/
protected[streaming] def initialize(time: Time) {
- if (zeroTime == null) {
- zeroTime = time
+ if (zeroTime != null) {
+ throw new Exception("ZeroTime is already initialized, cannot initialize it again")
}
+ zeroTime = time
logInfo(this + " initialized")
dependencies.foreach(_.initialize(zeroTime))
}
+ protected[streaming] def setContext(s: StreamingContext) {
+ if (ssc != null && ssc != s) {
+ throw new Exception("Context is already set, cannot set it again")
+ }
+ ssc = s
+ logInfo("Set context for " + this.getClass.getSimpleName)
+ dependencies.foreach(_.setContext(ssc))
+ }
+
+ protected[streaming] def setGraph(g: DStreamGraph) {
+ if (graph != null && graph != g) {
+ throw new Exception("Graph is already set, cannot set it again")
+ }
+ graph = g
+ dependencies.foreach(_.setGraph(graph))
+ }
+
/** This method checks whether the 'time' is valid wrt slideTime for generating RDD */
- protected def isTimeValid (time: Time): Boolean = {
+ protected def isTimeValid(time: Time): Boolean = {
if (!isInitialized) {
throw new Exception (this.toString + " has not been initialized")
} else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) {
@@ -158,13 +178,42 @@ extends Logging with Serializable {
}
}
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ println(this.getClass().getSimpleName + ".writeObject used")
+ if (graph != null) {
+ graph.synchronized {
+ if (graph.checkpointInProgress) {
+ oos.defaultWriteObject()
+ } else {
+ val msg = "Object of " + this.getClass.getName + " is being serialized " +
+ " possibly as a part of closure of an RDD operation. This is because " +
+ " the DStream object is being referred to from within the closure. " +
+ " Please rewrite the RDD operation inside this DStream to avoid this. " +
+ " This has been enforced to avoid bloating of Spark tasks " +
+ " with unnecessary objects."
+ throw new java.io.NotSerializableException(msg)
+ }
+ }
+ } else {
+ throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.")
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ println(this.getClass().getSimpleName + ".readObject used")
+ ois.defaultReadObject()
+ }
+
/**
* --------------
* DStream operations
* --------------
*/
-
- def map[U: ClassManifest](mapFunc: T => U) = new MappedDStream(this, ssc.sc.clean(mapFunc))
+ def map[U: ClassManifest](mapFunc: T => U) = {
+ new MappedDStream(this, ssc.sc.clean(mapFunc))
+ }
def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = {
new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc))
@@ -262,19 +311,15 @@ extends Logging with Serializable {
// Get all the RDDs between fromTime to toTime (both included)
def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = {
-
val rdds = new ArrayBuffer[RDD[T]]()
var time = toTime.floor(slideTime)
-
-
while (time >= zeroTime && time >= fromTime) {
getOrCompute(time) match {
case Some(rdd) => rdds += rdd
- case None => throw new Exception("Could not get old reduced RDD for time " + time)
+ case None => //throw new Exception("Could not get RDD for time " + time)
}
time -= slideTime
}
-
rdds.toSeq
}
@@ -284,12 +329,16 @@ extends Logging with Serializable {
}
-abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext)
- extends DStream[T](ssc) {
+abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext)
+ extends DStream[T](ssc_) {
override def dependencies = List()
- override def slideTime = ssc.batchDuration
+ override def slideTime = {
+ if (ssc == null) throw new Exception("ssc is null")
+ if (ssc.batchDuration == null) throw new Exception("ssc.batchDuration is null")
+ ssc.batchDuration
+ }
def start()
@@ -302,7 +351,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext)
*/
class MappedDStream[T: ClassManifest, U: ClassManifest] (
- @transient parent: DStream[T],
+ parent: DStream[T],
mapFunc: T => U
) extends DStream[U](parent.ssc) {
@@ -321,7 +370,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] (
*/
class FlatMappedDStream[T: ClassManifest, U: ClassManifest](
- @transient parent: DStream[T],
+ parent: DStream[T],
flatMapFunc: T => Traversable[U]
) extends DStream[U](parent.ssc) {
@@ -340,7 +389,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest](
*/
class FilteredDStream[T: ClassManifest](
- @transient parent: DStream[T],
+ parent: DStream[T],
filterFunc: T => Boolean
) extends DStream[T](parent.ssc) {
@@ -359,7 +408,7 @@ class FilteredDStream[T: ClassManifest](
*/
class MapPartitionedDStream[T: ClassManifest, U: ClassManifest](
- @transient parent: DStream[T],
+ parent: DStream[T],
mapPartFunc: Iterator[T] => Iterator[U]
) extends DStream[U](parent.ssc) {
@@ -377,7 +426,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest](
* TODO
*/
-class GlommedDStream[T: ClassManifest](@transient parent: DStream[T])
+class GlommedDStream[T: ClassManifest](parent: DStream[T])
extends DStream[Array[T]](parent.ssc) {
override def dependencies = List(parent)
@@ -395,7 +444,7 @@ class GlommedDStream[T: ClassManifest](@transient parent: DStream[T])
*/
class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
- @transient parent: DStream[(K,V)],
+ parent: DStream[(K,V)],
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiner: (C, C) => C,
@@ -420,7 +469,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
* TODO
*/
-class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]])
+class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]])
extends DStream[T](parents(0).ssc) {
if (parents.length == 0) {
@@ -459,7 +508,7 @@ class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]])
*/
class PerElementForEachDStream[T: ClassManifest] (
- @transient parent: DStream[T],
+ parent: DStream[T],
foreachFunc: T => Unit
) extends DStream[Unit](parent.ssc) {
@@ -490,7 +539,7 @@ class PerElementForEachDStream[T: ClassManifest] (
*/
class PerRDDForEachDStream[T: ClassManifest] (
- @transient parent: DStream[T],
+ parent: DStream[T],
foreachFunc: (RDD[T], Time) => Unit
) extends DStream[Unit](parent.ssc) {
@@ -518,15 +567,15 @@ class PerRDDForEachDStream[T: ClassManifest] (
*/
class TransformedDStream[T: ClassManifest, U: ClassManifest] (
- @transient parent: DStream[T],
+ parent: DStream[T],
transformFunc: (RDD[T], Time) => RDD[U]
) extends DStream[U](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies = List(parent)
- override def slideTime: Time = parent.slideTime
+ override def slideTime: Time = parent.slideTime
- override def compute(validTime: Time): Option[RDD[U]] = {
- parent.getOrCompute(validTime).map(transformFunc(_, validTime))
- }
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(transformFunc(_, validTime))
}
+}
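
The custom writeObject added to DStream ensures a DStream is only ever Java-serialized while its graph is mid-checkpoint, so accidental capture of a DStream inside an RDD closure now fails fast with a clear message. A toy, self-contained illustration of that guard follows; the class names are illustrative, not Spark's.

import java.io.{ByteArrayOutputStream, IOException, NotSerializableException, ObjectOutputStream}

class Graph extends Serializable {
  @volatile var checkpointInProgress = false
}

class Node(var graph: Graph) extends Serializable {
  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    // Refuse serialization unless the owning graph is checkpointing.
    if (graph == null || !graph.checkpointInProgress) {
      throw new NotSerializableException(
        "Node serialized outside a checkpoint, probably captured in a closure")
    }
    oos.defaultWriteObject()
  }
}

object GuardDemo extends App {
  def serialize(obj: AnyRef) {
    val oos = new ObjectOutputStream(new ByteArrayOutputStream())
    oos.writeObject(obj)
    oos.close()
  }

  val graph = new Graph
  val node = new Node(graph)

  try {
    serialize(node)                     // rejected: no checkpoint in progress
  } catch {
    case e: NotSerializableException => println("rejected: " + e.getMessage)
  }

  graph.checkpointInProgress = true     // the flag DStreamGraph.writeObject sets
  serialize(node)                       // now allowed
  println("serialized during checkpoint")
}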
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
new file mode 100644
index 0000000000..67859e0131
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -0,0 +1,80 @@
+package spark.streaming
+
+import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
+import collection.mutable.ArrayBuffer
+
+final class DStreamGraph extends Serializable {
+
+ private val inputStreams = new ArrayBuffer[InputDStream[_]]()
+ private val outputStreams = new ArrayBuffer[DStream[_]]()
+
+ private[streaming] var zeroTime: Time = null
+ private[streaming] var checkpointInProgress = false;
+
+ def started() = (zeroTime != null)
+
+ def start(time: Time) {
+ this.synchronized {
+ if (started) {
+ throw new Exception("DStream graph computation already started")
+ }
+ zeroTime = time
+ outputStreams.foreach(_.initialize(zeroTime))
+ inputStreams.par.foreach(_.start())
+ }
+
+ }
+
+ def stop() {
+ this.synchronized {
+ inputStreams.par.foreach(_.stop())
+ }
+ }
+
+ private[streaming] def setContext(ssc: StreamingContext) {
+ this.synchronized {
+ outputStreams.foreach(_.setContext(ssc))
+ }
+ }
+
+ def addInputStream(inputStream: InputDStream[_]) {
+ inputStream.setGraph(this)
+ inputStreams += inputStream
+ }
+
+ def addOutputStream(outputStream: DStream[_]) {
+ outputStream.setGraph(this)
+ outputStreams += outputStream
+ }
+
+ def getInputStreams() = inputStreams.toArray
+
+ def getOutputStreams() = outputStreams.toArray
+
+ def generateRDDs(time: Time): Seq[Job] = {
+ this.synchronized {
+ outputStreams.flatMap(outputStream => outputStream.generateJob(time))
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ this.synchronized {
+ checkpointInProgress = true
+ oos.defaultWriteObject()
+ checkpointInProgress = false
+ }
+ println("DStreamGraph.writeObject used")
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ this.synchronized {
+ checkpointInProgress = true
+ ois.defaultReadObject()
+ checkpointInProgress = false
+ }
+ println("DStreamGraph.readObject used")
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
index 96a64f0018..29ae89616e 100644
--- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
@@ -1,33 +1,45 @@
package spark.streaming
-import spark.SparkContext
import spark.RDD
-import spark.BlockRDD
import spark.UnionRDD
-import spark.storage.StorageLevel
-import spark.streaming._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import java.net.InetSocketAddress
-
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.fs.PathFilter
+import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import java.io.{ObjectInputStream, IOException}
class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest](
- ssc: StreamingContext,
- directory: Path,
+ @transient ssc_ : StreamingContext,
+ directory: String,
filter: PathFilter = FileInputDStream.defaultPathFilter,
newFilesOnly: Boolean = true)
- extends InputDStream[(K, V)](ssc) {
-
- val fs = directory.getFileSystem(new Configuration())
+ extends InputDStream[(K, V)](ssc_) {
+
+ @transient private var path_ : Path = null
+ @transient private var fs_ : FileSystem = null
+
+ /*
+ @transient @noinline lazy val path = {
+ //if (directory == null) throw new Exception("directory is null")
+ //println(directory)
+ new Path(directory)
+ }
+ @transient lazy val fs = path.getFileSystem(new Configuration())
+ */
+
var lastModTime: Long = 0
-
+
+ def path(): Path = {
+ if (path_ == null) path_ = new Path(directory)
+ path_
+ }
+
+ def fs(): FileSystem = {
+ if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
+ fs_
+ }
+
override def start() {
if (newFilesOnly) {
lastModTime = System.currentTimeMillis()
@@ -58,7 +70,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
}
}
- val newFiles = fs.listStatus(directory, newFilter)
+ val newFiles = fs.listStatus(path, newFilter)
logInfo("New files: " + newFiles.map(_.getPath).mkString(", "))
if (newFiles.length > 0) {
lastModTime = newFilter.latestModTime
@@ -67,10 +79,19 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)))
Some(newRDD)
}
+ /*
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ println(this.getClass().getSimpleName + ".readObject used")
+ ois.defaultReadObject()
+ println("HERE HERE" + this.directory)
+ }
+ */
+
}
object FileInputDStream {
- val defaultPathFilter = new PathFilter {
+ val defaultPathFilter = new PathFilter with Serializable {
def accept(path: Path): Boolean = {
val file = path.getName()
if (file.startsWith(".") || file.endsWith("_tmp")) {
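
FileInputDStream now keeps only the directory String in its serialized state and rebuilds the non-serializable Path and FileSystem handles on demand after deserialization. The same pattern in isolation, with illustrative class and path names:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

// Only `directory` survives serialization; the handles are @transient and
// re-created lazily by the accessors, as FileInputDStream does above.
class DirHandle(directory: String) extends Serializable {
  @transient private var path_ : Path = null
  @transient private var fs_ : FileSystem = null

  def path: Path = {
    if (path_ == null) path_ = new Path(directory)
    path_
  }

  def fs: FileSystem = {
    if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
    fs_
  }
}

object DirHandleDemo extends App {
  val buf = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(buf)
  oos.writeObject(new DirHandle("/tmp/stream-input"))
  oos.close()

  val ois = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
  val copy = ois.readObject().asInstanceOf[DirHandle]
  println(copy.path)   // handles rebuilt on first use after deserialization
}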
diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
index b0beaba94d..e161b5ba92 100644
--- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
@@ -10,9 +10,10 @@ import spark.SparkContext._
import spark.storage.StorageLevel
import scala.collection.mutable.ArrayBuffer
+import collection.SeqProxy
class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
- @transient parent: DStream[(K, V)],
+ parent: DStream[(K, V)],
reduceFunc: (V, V) => V,
invReduceFunc: (V, V) => V,
_windowTime: Time,
@@ -46,6 +47,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
}
override def compute(validTime: Time): Option[RDD[(K, V)]] = {
+ val reduceF = reduceFunc
+ val invReduceF = invReduceFunc
val currentTime = validTime
val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime)
@@ -84,54 +87,47 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
// Cogroup the reduced RDDs and merge the reduced values
val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner)
- val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _
- val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValuesFunc)
+ //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _
- Some(mergedValuesRDD)
- }
-
- def mergeValues(numOldValues: Int, numNewValues: Int)(seqOfValues: Seq[Seq[V]]): V = {
-
- if (seqOfValues.size != 1 + numOldValues + numNewValues) {
- throw new Exception("Unexpected number of sequences of reduced values")
- }
-
- // Getting reduced values "old time steps" that will be removed from current window
- val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head)
-
- // Getting reduced values "new time steps"
- val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
-
- if (seqOfValues(0).isEmpty) {
+ val numOldValues = oldRDDs.size
+ val numNewValues = newRDDs.size
- // If previous window's reduce value does not exist, then at least new values should exist
- if (newValues.isEmpty) {
- throw new Exception("Neither previous window has value for key, nor new values found")
+ val mergeValues = (seqOfValues: Seq[Seq[V]]) => {
+ if (seqOfValues.size != 1 + numOldValues + numNewValues) {
+ throw new Exception("Unexpected number of sequences of reduced values")
}
+ // Getting reduced values "old time steps" that will be removed from current window
+ val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head)
+ // Getting reduced values "new time steps"
+ val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
+ if (seqOfValues(0).isEmpty) {
+ // If previous window's reduce value does not exist, then at least new values should exist
+ if (newValues.isEmpty) {
+ throw new Exception("Neither previous window has value for key, nor new values found")
+ }
+ // Reduce the new values
+ newValues.reduce(reduceF) // return
+ } else {
+ // Get the previous window's reduced value
+ var tempValue = seqOfValues(0).head
+ // If old values exists, then inverse reduce then from previous value
+ if (!oldValues.isEmpty) {
+ tempValue = invReduceF(tempValue, oldValues.reduce(reduceF))
+ }
+ // If new values exists, then reduce them with previous value
+ if (!newValues.isEmpty) {
+ tempValue = reduceF(tempValue, newValues.reduce(reduceF))
+ }
+ tempValue // return
+ }
+ }
- // Reduce the new values
- // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _))
- return newValues.reduce(reduceFunc)
- } else {
+ val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues)
- // Get the previous window's reduced value
- var tempValue = seqOfValues(0).head
+ Some(mergedValuesRDD)
+ }
- // If old values exists, then inverse reduce then from previous value
- if (!oldValues.isEmpty) {
- // println("old values = " + oldValues.map(_.toString).reduce(_ + " " + _))
- tempValue = invReduceFunc(tempValue, oldValues.reduce(reduceFunc))
- }
- // If new values exists, then reduce them with previous value
- if (!newValues.isEmpty) {
- // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _))
- tempValue = reduceFunc(tempValue, newValues.reduce(reduceFunc))
- }
- // println("final value = " + tempValue)
- return tempValue
- }
- }
}
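
The rewrite of mergeValues from a method into a local closure, together with the local copies reduceF and invReduceF, exists so the function shipped to mapValues captures a few small values instead of the whole DStream, which the new writeObject guard would reject. A rough, self-contained illustration of the difference, assuming standard Scala/Java serialization behavior (exact capture details can vary by Scala version):

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Deliberately not Serializable, standing in for a DStream referenced from a closure.
class Holder(val f: Int => Int) {
  // Reads the field inside the lambda, so the lambda captures `this`.
  def closureOverField: Int => Int = x => f(x)

  // Copies the field to a local first, so only that local is captured.
  def closureOverLocal: Int => Int = {
    val g = f
    x => g(x)
  }
}

object CaptureDemo extends App {
  def serializes(obj: AnyRef): Boolean =
    try {
      val oos = new ObjectOutputStream(new ByteArrayOutputStream())
      oos.writeObject(obj)
      oos.close()
      true
    } catch {
      case _: NotSerializableException => false
    }

  val h = new Holder(_ + 1)
  println(serializes(h.closureOverLocal))   // true: only the function value rides along
  println(serializes(h.closureOverField))   // false: the non-serializable Holder is captured
}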
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
index d2e907378d..d62b7e7140 100644
--- a/streaming/src/main/scala/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -11,45 +11,44 @@ import scala.collection.mutable.HashMap
sealed trait SchedulerMessage
case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage
-class Scheduler(
- ssc: StreamingContext,
- inputStreams: Array[InputDStream[_]],
- outputStreams: Array[DStream[_]])
+class Scheduler(ssc: StreamingContext)
extends Logging {
initLogging()
+ val graph = ssc.graph
val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt
val jobManager = new JobManager(ssc, concurrentJobs)
val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock")
val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_))
-
+
+
def start() {
- val zeroTime = Time(timer.start())
- outputStreams.foreach(_.initialize(zeroTime))
- inputStreams.par.foreach(_.start())
+ if (graph.started) {
+ timer.restart(graph.zeroTime.milliseconds)
+ } else {
+ val zeroTime = Time(timer.start())
+ graph.start(zeroTime)
+ }
logInfo("Scheduler started")
}
def stop() {
timer.stop()
- inputStreams.par.foreach(_.stop())
+ graph.stop()
logInfo("Scheduler stopped")
}
- def generateRDDs (time: Time) {
+ def generateRDDs(time: Time) {
SparkEnv.set(ssc.env)
logInfo("\n-----------------------------------------------------\n")
logInfo("Generating RDDs for time " + time)
- outputStreams.foreach(outputStream => {
- outputStream.generateJob(time) match {
- case Some(job) => submitJob(job)
- case None =>
- }
- }
- )
+ graph.generateRDDs(time).foreach(submitJob)
logInfo("Generated RDDs for time " + time)
+ if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) {
+ ssc.checkpoint()
+ }
}
def submitJob(job: Job) {
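
The new trigger in generateRDDs checkpoints whenever the elapsed time since the graph's zero time is an exact multiple of the checkpoint interval. A small worked example of that arithmetic with made-up millisecond values:

object CheckpointTriggerDemo extends App {
  val zeroTime = 0L                 // graph.zeroTime (ms), hypothetical
  val batchDuration = 2000L         // one batch every 2 s
  val checkpointInterval = 10000L   // ssc.checkpointInterval = 10 s

  for (batch <- 1 to 10) {
    val time = zeroTime + batch * batchDuration
    val checkpointNow = (time - zeroTime) % checkpointInterval == 0
    println("batch time " + time + " ms -> checkpoint: " + checkpointNow)
  }
  // Checkpoints fire at 10000 and 20000 ms, i.e. every fifth batch here.
}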
diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala
index c40f70c91d..d223f25dfc 100644
--- a/streaming/src/main/scala/spark/streaming/StateDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala
@@ -7,6 +7,12 @@ import spark.MapPartitionsRDD
import spark.SparkContext._
import spark.storage.StorageLevel
+
+class StateRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U], rememberPartitioner: Boolean)
+ extends MapPartitionsRDD[U, T](prev, f) {
+ override val partitioner = if (rememberPartitioner) prev.partitioner else None
+}
+
class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest](
@transient parent: DStream[(K, V)],
updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)],
@@ -14,11 +20,6 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife
rememberPartitioner: Boolean
) extends DStream[(K, S)](parent.ssc) {
- class SpecialMapPartitionsRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U])
- extends MapPartitionsRDD(prev, f) {
- override val partitioner = if (rememberPartitioner) prev.partitioner else None
- }
-
override def dependencies = List(parent)
override def slideTime = parent.slideTime
@@ -79,19 +80,18 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife
// first map the cogrouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
- val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => {
+ val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => {
val i = iterator.map(t => {
(t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S]))
})
updateFuncLocal(i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
- val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc)
+ val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner)
//logDebug("Generating state RDD for time " + validTime)
return Some(stateRDD)
}
case None => { // If parent RDD does not exist, then return old state RDD
- //logDebug("Generating state RDD for time " + validTime + " (no change)")
return Some(prevStateRDD)
}
}
@@ -107,12 +107,12 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife
// first map the grouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
- val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => {
+ val finalFunc = (iterator: Iterator[(K, Seq[V])]) => {
updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S])))
}
val groupedRDD = parentRDD.groupByKey(partitioner)
- val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc)
+ val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner)
//logDebug("Generating state RDD for time " + validTime + " (first)")
return Some(sessionRDD)
}
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 12f3626680..1499ef4ea2 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -21,31 +21,70 @@ import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
-class StreamingContext (@transient val sc: SparkContext) extends Logging {
+class StreamingContext (
+ sc_ : SparkContext,
+ cp_ : Checkpoint
+ ) extends Logging {
+
+ def this(sparkContext: SparkContext) = this(sparkContext, null)
def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) =
- this(new SparkContext(master, frameworkName, sparkHome, jars))
+ this(new SparkContext(master, frameworkName, sparkHome, jars), null)
+
+ def this(file: String) = this(null, Checkpoint.loadFromFile(file))
+
+ def this(cp_ : Checkpoint) = this(null, cp_)
initLogging()
+ if (sc_ == null && cp_ == null) {
+ throw new Exception("Streaming Context cannot be initilalized with " +
+ "both SparkContext and checkpoint as null")
+ }
+
+ val isCheckpointPresent = (cp_ != null)
+
+ val sc: SparkContext = {
+ if (isCheckpointPresent) {
+ new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars)
+ } else {
+ sc_
+ }
+ }
+
val env = SparkEnv.get
-
- val inputStreams = new ArrayBuffer[InputDStream[_]]()
- val outputStreams = new ArrayBuffer[DStream[_]]()
+
+ val graph: DStreamGraph = {
+ if (isCheckpointPresent) {
+
+ cp_.graph.setContext(this)
+ cp_.graph
+ } else {
+ new DStreamGraph()
+ }
+ }
+
val nextNetworkInputStreamId = new AtomicInteger(0)
- var batchDuration: Time = null
- var scheduler: Scheduler = null
+ var batchDuration: Time = if (isCheckpointPresent) cp_.batchDuration else null
+ var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null
+ var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null
var networkInputTracker: NetworkInputTracker = null
- var receiverJobThread: Thread = null
-
- def setBatchDuration(duration: Long) {
- setBatchDuration(Time(duration))
- }
-
+ var receiverJobThread: Thread = null
+ var scheduler: Scheduler = null
+
def setBatchDuration(duration: Time) {
+ if (batchDuration != null) {
+ throw new Exception("Batch duration alread set as " + batchDuration +
+ ". cannot set it again.")
+ }
batchDuration = duration
}
+
+ def setCheckpointDetails(file: String, interval: Time) {
+ checkpointFile = file
+ checkpointInterval = interval
+ }
private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
@@ -59,7 +98,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
converter: (InputStream) => Iterator[T]
): DStream[T] = {
val inputStream = new ObjectInputDStream[T](this, hostname, port, converter)
- inputStreams += inputStream
+ graph.addInputStream(inputStream)
inputStream
}
@@ -69,7 +108,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_2
): DStream[T] = {
val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel)
- inputStreams += inputStream
+ graph.addInputStream(inputStream)
inputStream
}
@@ -94,8 +133,8 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
V: ClassManifest,
F <: NewInputFormat[K, V]: ClassManifest
](directory: String): DStream[(K, V)] = {
- val inputStream = new FileInputDStream[K, V, F](this, new Path(directory))
- inputStreams += inputStream
+ val inputStream = new FileInputDStream[K, V, F](this, directory)
+ graph.addInputStream(inputStream)
inputStream
}
@@ -113,24 +152,31 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
defaultRDD: RDD[T] = null
): DStream[T] = {
val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD)
- inputStreams += inputStream
+ graph.addInputStream(inputStream)
inputStream
}
- def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = {
+ def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = {
val queue = new Queue[RDD[T]]
val inputStream = createQueueStream(queue, true, null)
queue ++= iterator
inputStream
- }
+ }
+
+ /**
+ * This function registers a InputDStream as an input stream that will be
+ * started (InputDStream.start() called) to get the input data streams.
+ */
+ def registerInputStream(inputStream: InputDStream[_]) {
+ graph.addInputStream(inputStream)
+ }
-
/**
* This function registers a DStream as an output stream that will be
* computed every interval.
*/
- def registerOutputStream (outputStream: DStream[_]) {
- outputStreams += outputStream
+ def registerOutputStream(outputStream: DStream[_]) {
+ graph.addOutputStream(outputStream)
}
/**
@@ -143,13 +189,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
if (batchDuration < Milliseconds(100)) {
logWarning("Batch duration of " + batchDuration + " is very low")
}
- if (inputStreams.size == 0) {
- throw new Exception("No input streams created, so nothing to take input from")
- }
- if (outputStreams.size == 0) {
+ if (graph.getOutputStreams().size == 0) {
throw new Exception("No output streams registered, so nothing to execute")
}
-
}
/**
@@ -157,7 +199,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
*/
def start() {
verify()
- val networkInputStreams = inputStreams.filter(s => s match {
+ val networkInputStreams = graph.getInputStreams().filter(s => s match {
case n: NetworkInputDStream[_] => true
case _ => false
}).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
@@ -169,8 +211,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
}
Thread.sleep(1000)
- // Start the scheduler
- scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray)
+
+ // Start the scheduler
+ scheduler = new Scheduler(this)
scheduler.start()
}
@@ -189,6 +232,10 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging {
logInfo("StreamingContext stopped")
}
+
+ def checkpoint() {
+ new Checkpoint(this).saveToFile(checkpointFile)
+ }
}
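
With the new primary constructor, a StreamingContext can be built either fresh from a master URL or rebuilt from a checkpoint file written by ssc.checkpoint(). A minimal sketch of both paths against the API added above; the master URL and file paths are placeholders.

import spark.streaming._

object ContextConstructionSketch {
  // Fresh context: set the batch duration and where/how often to checkpoint.
  def fresh(): StreamingContext = {
    val ssc = new StreamingContext("local", "ContextConstructionSketch")
    ssc.setBatchDuration(Seconds(1))
    ssc.setCheckpointDetails("/tmp/streaming.checkpoint", Seconds(10))
    ssc
  }

  // Recovered context: the SparkContext, DStream graph, batch duration and
  // checkpoint settings are restored from the serialized Checkpoint.
  def recovered(): StreamingContext = new StreamingContext("/tmp/streaming.checkpoint")
}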
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
new file mode 100644
index 0000000000..c725035a8a
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
@@ -0,0 +1,76 @@
+package spark.streaming.examples
+
+import spark.streaming._
+import spark.streaming.StreamingContext._
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
+
+object FileStreamWithCheckpoint {
+
+ def main(args: Array[String]) {
+
+ if (args.size != 3) {
+ println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>")
+ println("FileStreamWithCheckpoint restart <directory> <checkpoint file>")
+ System.exit(-1)
+ }
+
+ val directory = new Path(args(1))
+ val checkpointFile = args(2)
+
+ val ssc: StreamingContext = {
+
+ if (args(0) == "restart") {
+
+ // Recreated streaming context from specified checkpoint file
+ new StreamingContext(checkpointFile)
+
+ } else {
+
+ // Create directory if it does not exist
+ val fs = directory.getFileSystem(new Configuration())
+ if (!fs.exists(directory)) fs.mkdirs(directory)
+
+ // Create new streaming context
+ val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint")
+ ssc_.setBatchDuration(Seconds(1))
+ ssc_.setCheckpointDetails(checkpointFile, Seconds(1))
+
+ // Setup the streaming computation
+ val inputStream = ssc_.createTextFileStream(directory.toString)
+ val words = inputStream.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+
+ ssc_
+ }
+ }
+
+ // Start the stream computation
+ startFileWritingThread(directory.toString)
+ ssc.start()
+ }
+
+ def startFileWritingThread(directory: String) {
+
+ val fs = new Path(directory).getFileSystem(new Configuration())
+
+ val fileWritingThread = new Thread() {
+ override def run() {
+ val r = new scala.util.Random()
+ val text = "This is a sample text file with a random number "
+ while(true) {
+ val number = r.nextInt()
+ val file = new Path(directory, number.toString)
+ val fos = fs.create(file)
+ fos.writeChars(text + number)
+ fos.close()
+ println("Created text file " + file)
+ Thread.sleep(1000)
+ }
+ }
+ }
+ fileWritingThread.start()
+ }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala
index 7237142c7c..b1faa65c17 100644
--- a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala
@@ -50,7 +50,7 @@ object Grep2 {
println("Data count: " + data.count())
val sentences = new ConstantInputDStream(ssc, data)
- ssc.inputStreams += sentences
+ ssc.registerInputStream(sentences)
sentences.filter(_.contains("Culpepper")).count().foreachRDD(r =>
println("Grep count: " + r.collect().mkString))
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
index c22949d7b9..8390f4af94 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
@@ -93,7 +93,7 @@ object WordCount2 {
println("Data count: " + data.count())
val sentences = new ConstantInputDStream(ssc, data)
- ssc.inputStreams += sentences
+ ssc.registerInputStream(sentences)
import WordCount2_ExtraFunctions._
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
index 3658cb302d..fc7567322b 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
@@ -50,7 +50,7 @@ object WordMax2 {
println("Data count: " + data.count())
val sentences = new ConstantInputDStream(ssc, data)
- ssc.inputStreams += sentences
+ ssc.registerInputStream(sentences)
import WordCount2_ExtraFunctions._
diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
index 5da9fa6ecc..7f19b26a79 100644
--- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
@@ -17,12 +17,23 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
}
var nextTime = 0L
-
- def start(): Long = {
- nextTime = (math.floor(clock.currentTime / period) + 1).toLong * period
- thread.start()
+
+ def start(startTime: Long): Long = {
+ nextTime = startTime
+ thread.start()
nextTime
}
+
+ def start(): Long = {
+ val startTime = math.ceil(clock.currentTime / period).toLong * period
+ start(startTime)
+ }
+
+ def restart(originalStartTime: Long): Long = {
+ val gap = clock.currentTime - originalStartTime
+ val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime
+ start(newStartTime)
+ }
def stop() {
thread.interrupt()
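
The restart() method realigns the timer after a recovery: it advances the original start time by enough whole periods to reach the current clock time. A worked example of that alignment with hypothetical millisecond values (this sketch uses a floating-point ceil for clarity):

object RestartAlignmentDemo extends App {
  val period = 1000L                // timer period (ms)
  val originalStartTime = 5000L     // zero time of the pre-restart run
  val now = 12345L                  // stand-in for clock.currentTime after recovery

  val gap = now - originalStartTime                                   // 7345 ms
  val newStartTime =
    math.ceil(gap.toDouble / period).toLong * period + originalStartTime

  println(newStartTime)             // 13000: the next period boundary after `now`
}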