-rw-r--r--  core/src/main/scala/spark/SparkContext.scala  11
-rw-r--r--  core/src/main/scala/spark/Utils.scala  2
-rw-r--r--  core/src/main/scala/spark/util/RateLimitedOutputStream.scala  56
-rw-r--r--  project/SparkBuild.scala  8
-rwxr-xr-x  run  2
-rw-r--r--  sentences.txt  3
-rwxr-xr-x  startTrigger.sh  3
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala  96
-rw-r--r--  streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala  38
-rw-r--r--  streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala  18
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala  660
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala  124
-rw-r--r--  streaming/src/main/scala/spark/streaming/FileInputDStream.scala  87
-rw-r--r--  streaming/src/main/scala/spark/streaming/Interval.scala  50
-rw-r--r--  streaming/src/main/scala/spark/streaming/Job.scala  22
-rw-r--r--  streaming/src/main/scala/spark/streaming/JobManager.scala  32
-rw-r--r--  streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala  151
-rw-r--r--  streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala  118
-rw-r--r--  streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala  236
-rw-r--r--  streaming/src/main/scala/spark/streaming/QueueInputDStream.scala  40
-rw-r--r--  streaming/src/main/scala/spark/streaming/RawInputDStream.scala  83
-rw-r--r--  streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala  143
-rw-r--r--  streaming/src/main/scala/spark/streaming/Scheduler.scala  69
-rw-r--r--  streaming/src/main/scala/spark/streaming/SocketInputDStream.scala  173
-rw-r--r--  streaming/src/main/scala/spark/streaming/StateDStream.scala  130
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala  229
-rw-r--r--  streaming/src/main/scala/spark/streaming/Time.scala  56
-rw-r--r--  streaming/src/main/scala/spark/streaming/WindowedDStream.scala  36
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/CountRaw.scala  32
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStream.scala  47
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala  76
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/Grep2.scala  64
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala  33
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/QueueStream.scala  41
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala  95
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCount2.scala  115
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala  26
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala  25
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala  51
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordMax2.scala  73
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/Clock.scala  84
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala  157
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/RawTextSender.scala  60
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala  73
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala  67
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala  92
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala  23
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/TestGenerator.scala  107
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala  119
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala  244
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala  39
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala  421
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala  374
-rw-r--r--  streaming/src/test/resources/log4j.properties  8
-rw-r--r--  streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala  213
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala  53
-rw-r--r--  streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala  114
-rw-r--r--  streaming/src/test/scala/spark/streaming/TestSuiteBase.scala  216
-rw-r--r--  streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala  188
59 files changed, 6001 insertions, 5 deletions
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index d26cccbfe1..0d37075ef3 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -58,10 +58,10 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
* @param environment Environment variables to set on worker nodes.
*/
class SparkContext(
- master: String,
- jobName: String,
+ val master: String,
+ val jobName: String,
val sparkHome: String,
- jars: Seq[String],
+ val jars: Seq[String],
environment: Map[String, String])
extends Logging {
@@ -595,6 +595,11 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
+
+ // TODO: temporary hack for using HDFS as input in streaming
+ var inputFile: String = null
+ var idealPartitions: Int = 1
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 567c4b1475..1bdde25896 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -247,7 +247,7 @@ private object Utils extends Logging {
* millisecond.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
- return " " + (System.currentTimeMillis - startTimeMs) + " ms "
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms"
}
/**
diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
new file mode 100644
index 0000000000..d11ed163ce
--- /dev/null
+++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
@@ -0,0 +1,56 @@
+package spark.util
+
+import java.io.OutputStream
+
+class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream {
+ var lastSyncTime = System.nanoTime()
+ var bytesWrittenSinceSync: Long = 0
+
+ override def write(b: Int) {
+ waitToWrite(1)
+ out.write(b)
+ }
+
+ override def write(bytes: Array[Byte]) {
+ write(bytes, 0, bytes.length)
+ }
+
+ override def write(bytes: Array[Byte], offset: Int, length: Int) {
+ val CHUNK_SIZE = 8192
+ var pos = 0
+ while (pos < length) {
+ val writeSize = math.min(length - pos, CHUNK_SIZE)
+ waitToWrite(writeSize)
+ out.write(bytes, offset + pos, writeSize)
+ pos += writeSize
+ }
+ }
+
+ def waitToWrite(numBytes: Int) {
+ while (true) {
+ val now = System.nanoTime()
+ val elapsed = math.max(now - lastSyncTime, 1)
+ val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9)
+ if (rate < bytesPerSec) {
+ // It's okay to write; just update some variables and return
+ bytesWrittenSinceSync += numBytes
+ if (now > lastSyncTime + (1e10).toLong) {
+ // Ten seconds have passed since lastSyncTime; let's resync
+ lastSyncTime = now
+ bytesWrittenSinceSync = numBytes
+ }
+ return
+ } else {
+ Thread.sleep(5)
+ }
+ }
+ }
+
+ override def flush() {
+ out.flush()
+ }
+
+ override def close() {
+ out.close()
+ }
+}
\ No newline at end of file
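
A brief usage sketch for the new RateLimitedOutputStream (illustrative only, not part of this commit; the file path and rate are arbitrary). Writes are chunked into 8 KB pieces, and waitToWrite sleeps in 5 ms increments whenever the byte rate observed since the last sync would exceed bytesPerSec.

    import java.io.FileOutputStream
    import spark.util.RateLimitedOutputStream

    object RateLimitedExample {
      def main(args: Array[String]) {
        // Cap throughput at roughly 1 MB/s.
        val out = new RateLimitedOutputStream(new FileOutputStream("/tmp/limited.out"), 1024 * 1024)
        val payload = ("x" * 8192).getBytes
        for (i <- 1 to 1000) {
          out.write(payload)   // blocks (sleeps) as needed to respect the rate
        }
        out.close()
      }
    }
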
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 2f67bb9921..688bb16a03 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -17,7 +17,7 @@ object SparkBuild extends Build {
//val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1"
//val HADOOP_MAJOR_VERSION = "2"
- lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel)
+ lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming)
lazy val core = Project("core", file("core"), settings = coreSettings)
@@ -27,6 +27,8 @@ object SparkBuild extends Build {
lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn (core)
+ lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core)
+
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -153,6 +155,10 @@ object SparkBuild extends Build {
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
+ def streamingSettings = sharedSettings ++ Seq(
+ name := "spark-streaming"
+ ) ++ assemblySettings ++ extraAssemblySettings
+
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
diff --git a/run b/run
index 83175e84de..a363599cf0 100755
--- a/run
+++ b/run
@@ -63,6 +63,7 @@ CORE_DIR="$FWDIR/core"
REPL_DIR="$FWDIR/repl"
EXAMPLES_DIR="$FWDIR/examples"
BAGEL_DIR="$FWDIR/bagel"
+STREAMING_DIR="$FWDIR/streaming"
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH"
@@ -74,6 +75,7 @@ fi
CLASSPATH+=":$CORE_DIR/src/main/resources"
CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
for jar in `find $FWDIR/lib_managed/jars -name '*jar'`; do
CLASSPATH+=":$jar"
done
diff --git a/sentences.txt b/sentences.txt
new file mode 100644
index 0000000000..fedf96c66e
--- /dev/null
+++ b/sentences.txt
@@ -0,0 +1,3 @@
+Hello world!
+What's up?
+There is no cow level
diff --git a/startTrigger.sh b/startTrigger.sh
new file mode 100755
index 0000000000..373dbda93e
--- /dev/null
+++ b/startTrigger.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+./run spark.streaming.SentenceGenerator localhost 7078 sentences.txt 1
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
new file mode 100644
index 0000000000..83a43d15cb
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -0,0 +1,96 @@
+package spark.streaming
+
+import spark.Utils
+
+import org.apache.hadoop.fs.{FileUtil, Path}
+import org.apache.hadoop.conf.Configuration
+
+import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream}
+
+
+class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable {
+ val master = ssc.sc.master
+ val framework = ssc.sc.jobName
+ val sparkHome = ssc.sc.sparkHome
+ val jars = ssc.sc.jars
+ val graph = ssc.graph
+ val checkpointFile = ssc.checkpointFile
+ val checkpointInterval = ssc.checkpointInterval
+
+ validate()
+
+ def validate() {
+ assert(master != null, "Checkpoint.master is null")
+ assert(framework != null, "Checkpoint.framework is null")
+ assert(graph != null, "Checkpoint.graph is null")
+ assert(checkpointTime != null, "Checkpoint.checkpointTime is null")
+ }
+
+ def saveToFile(file: String = checkpointFile) {
+ val path = new Path(file)
+ val conf = new Configuration()
+ val fs = path.getFileSystem(conf)
+ if (fs.exists(path)) {
+ val bkPath = new Path(path.getParent, path.getName + ".bk")
+ FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
+ //logInfo("Moved existing checkpoint file to " + bkPath)
+ }
+ val fos = fs.create(path)
+ val oos = new ObjectOutputStream(fos)
+ oos.writeObject(this)
+ oos.close()
+ fs.close()
+ }
+
+ def toBytes(): Array[Byte] = {
+ val bytes = Utils.serialize(this)
+ bytes
+ }
+}
+
+object Checkpoint {
+
+ def loadFromFile(file: String): Checkpoint = {
+ try {
+ val path = new Path(file)
+ val conf = new Configuration()
+ val fs = path.getFileSystem(conf)
+ if (!fs.exists(path)) {
+ throw new Exception("Checkpoint file '" + file + "' does not exist")
+ }
+ val fis = fs.open(path)
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+ // to find classes, which may be the wrong class loader. Hence, an inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+ // loader to find and load classes. This is a well known Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ cp
+ } catch {
+ case e: Exception =>
+ e.printStackTrace()
+ throw new Exception("Could not load checkpoint file '" + file + "'", e)
+ }
+ }
+
+ def fromBytes(bytes: Array[Byte]): Checkpoint = {
+ val cp = Utils.deserialize[Checkpoint](bytes)
+ cp.validate()
+ cp
+ }
+}
+
+class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) {
+ override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+ try {
+ return loader.loadClass(desc.getName())
+ } catch {
+ case e: Exception =>
+ }
+ return super.resolveClass(desc)
+ }
+}
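
The ObjectInputStreamWithLoader above exists because a plain ObjectInputStream can pick the wrong class loader when deserialization runs on a Spark worker thread. A minimal round-trip sketch (a hypothetical helper, mirroring what Checkpoint.toBytes/fromBytes and loadFromFile do, but going through the JDK streams directly):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutputStream}

    // Serialize a checkpoint to bytes and read it back using the current
    // thread's context class loader, as Checkpoint.loadFromFile does.
    def roundTrip(cp: Checkpoint): Checkpoint = {
      val baos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(baos)
      oos.writeObject(cp)
      oos.close()
      val ois = new ObjectInputStreamWithLoader(
        new ByteArrayInputStream(baos.toByteArray),
        Thread.currentThread().getContextClassLoader)
      val restored = ois.readObject.asInstanceOf[Checkpoint]
      ois.close()
      restored.validate()
      restored
    }
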
diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala
new file mode 100644
index 0000000000..61d088eddb
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala
@@ -0,0 +1,38 @@
+package spark.streaming
+
+import spark.{RDD, Partitioner}
+import spark.rdd.CoGroupedRDD
+
+class CoGroupedDStream[K : ClassManifest](
+ parents: Seq[DStream[(_, _)]],
+ partitioner: Partitioner
+ ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) {
+
+ if (parents.length == 0) {
+ throw new IllegalArgumentException("Empty array of parents")
+ }
+
+ if (parents.map(_.ssc).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different StreamingContexts")
+ }
+
+ if (parents.map(_.slideTime).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different slide times")
+ }
+
+ override def dependencies = parents.toList
+
+ override def slideTime = parents.head.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = {
+ val part = partitioner
+ val rdds = parents.flatMap(_.getOrCompute(validTime))
+ if (rdds.size > 0) {
+ val q = new CoGroupedRDD[K](rdds, part)
+ Some(q)
+ } else {
+ None
+ }
+ }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
new file mode 100644
index 0000000000..80150708fd
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
@@ -0,0 +1,18 @@
+package spark.streaming
+
+import spark.RDD
+
+/**
+ * An input stream that always returns the same RDD on each timestep. Useful for testing.
+ */
+class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T])
+ extends InputDStream[T](ssc_) {
+
+ override def start() {}
+
+ override def stop() {}
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ Some(rdd)
+ }
+}
\ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
new file mode 100644
index 0000000000..12d7ba97ea
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -0,0 +1,660 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+
+import spark._
+import spark.SparkContext._
+import spark.rdd._
+import spark.storage.StorageLevel
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import java.util.concurrent.ArrayBlockingQueue
+import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
+import scala.Some
+
+abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
+extends Serializable with Logging {
+
+ initLogging()
+
+ /**
+ * ----------------------------------------------
+ * Methods that must be implemented by subclasses
+ * ----------------------------------------------
+ */
+
+ // Time by which the window slides in this DStream
+ def slideTime: Time
+
+ // List of parent DStreams that this DStream depends on
+ def dependencies: List[DStream[_]]
+
+ // Key method that computes RDD for a valid time
+ def compute (validTime: Time): Option[RDD[T]]
+
+ /**
+ * ---------------------------------------
+ * Other general fields and methods of DStream
+ * ---------------------------------------
+ */
+
+ // RDDs generated, marked as protected[streaming] so that test suites can access them
+ protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] ()
+
+ // Time zero for the DStream
+ protected var zeroTime: Time = null
+
+ // Duration for which the DStream will remember each RDD created
+ protected var rememberDuration: Time = null
+
+ // Storage level of the RDDs in the stream
+ protected var storageLevel: StorageLevel = StorageLevel.NONE
+
+ // Checkpoint level and checkpoint interval
+ protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint
+ protected var checkpointInterval: Time = null
+
+ // Reference to whole DStream graph
+ protected var graph: DStreamGraph = null
+
+ def isInitialized = (zeroTime != null)
+
+ // Duration for which the DStream requires its parent DStream to remember each RDD created
+ def parentRememberDuration = rememberDuration
+
+ // Change the storage level of the RDDs generated by this DStream
+ def persist(
+ storageLevel: StorageLevel,
+ checkpointLevel: StorageLevel,
+ checkpointInterval: Time): DStream[T] = {
+ if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) {
+ // TODO: not sure this is necessary for DStreams
+ throw new UnsupportedOperationException(
+ "Cannot change storage level of an DStream after it was already assigned a level")
+ }
+ this.storageLevel = storageLevel
+ this.checkpointLevel = checkpointLevel
+ this.checkpointInterval = checkpointInterval
+ this
+ }
+
+ // Set caching level for the RDDs created by this DStream
+ def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null)
+
+ def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY)
+
+ // Turn on the default caching level for this DStream
+ def cache(): DStream[T] = persist()
+
+ /**
+ * This method initializes the DStream by setting the "zero" time, based on which
+ * the validity of future times is calculated. This method also recursively initializes
+ * its parent DStreams.
+ */
+ protected[streaming] def initialize(time: Time) {
+ if (zeroTime != null && zeroTime != time) {
+ throw new Exception("ZeroTime is already initialized to " + zeroTime
+ + ", cannot initialize it again to " + time)
+ }
+ zeroTime = time
+ dependencies.foreach(_.initialize(zeroTime))
+ logInfo("Initialized " + this)
+ }
+
+ protected[streaming] def setContext(s: StreamingContext) {
+ if (ssc != null && ssc != s) {
+ throw new Exception("Context is already set in " + this + ", cannot set it again")
+ }
+ ssc = s
+ logInfo("Set context for " + this)
+ dependencies.foreach(_.setContext(ssc))
+ }
+
+ protected[streaming] def setGraph(g: DStreamGraph) {
+ if (graph != null && graph != g) {
+ throw new Exception("Graph is already set in " + this + ", cannot set it again")
+ }
+ graph = g
+ dependencies.foreach(_.setGraph(graph))
+ }
+
+ protected[streaming] def setRememberDuration(duration: Time = slideTime) {
+ if (duration == null) {
+ throw new Exception("Duration for remembering RDDs cannot be set to null for " + this)
+ } else if (rememberDuration != null && duration < rememberDuration) {
+ logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration
+ + " to " + duration + " for " + this)
+ } else {
+ rememberDuration = duration
+ dependencies.foreach(_.setRememberDuration(parentRememberDuration))
+ logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
+ }
+ }
+
+ /** Checks whether the given 'time' is valid, with respect to slideTime, for generating an RDD */
+ protected def isTimeValid(time: Time): Boolean = {
+ if (!isInitialized) {
+ throw new Exception (this + " has not been initialized")
+ } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) {
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * This method either retrieves a precomputed RDD of this DStream,
+ * or computes the RDD (if the time is valid)
+ */
+ def getOrCompute(time: Time): Option[RDD[T]] = {
+ // If this DStream was not initialized (i.e., zeroTime not set), then do it
+ // If RDD was already generated, then retrieve it from HashMap
+ generatedRDDs.get(time) match {
+
+ // If an RDD was already generated and is being reused, then
+ // probably all RDDs in this DStream will be reused and hence should be cached
+ case Some(oldRDD) => Some(oldRDD)
+
+ // if RDD was not generated, and if the time is valid
+ // (based on sliding time of this DStream), then generate the RDD
+ case None => {
+ if (isTimeValid(time)) {
+ compute(time) match {
+ case Some(newRDD) =>
+ if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
+ newRDD.persist(checkpointLevel)
+ logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
+ } else if (storageLevel != StorageLevel.NONE) {
+ newRDD.persist(storageLevel)
+ logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
+ }
+ generatedRDDs.put(time, newRDD)
+ Some(newRDD)
+ case None =>
+ None
+ }
+ } else {
+ None
+ }
+ }
+ }
+ }
+
+ /**
+ * This method generates a Spark Streaming job for the given time
+ * and may need to be overridden by subclasses
+ */
+ def generateJob(time: Time): Option[Job] = {
+ getOrCompute(time) match {
+ case Some(rdd) => {
+ val jobFunc = () => {
+ val emptyFunc = { (iterator: Iterator[T]) => {} }
+ ssc.sc.runJob(rdd, emptyFunc)
+ }
+ Some(new Job(time, jobFunc))
+ }
+ case None => None
+ }
+ }
+
+ def forgetOldRDDs(time: Time) {
+ val keys = generatedRDDs.keys
+ var numForgotten = 0
+ keys.foreach(t => {
+ if (t <= (time - rememberDuration)) {
+ generatedRDDs.remove(t)
+ numForgotten += 1
+ //logInfo("Forgot RDD of time " + t + " from " + this)
+ }
+ })
+ logInfo("Forgot " + numForgotten + " RDDs from " + this)
+ dependencies.foreach(_.forgetOldRDDs(time))
+ }
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ logDebug(this.getClass().getSimpleName + ".writeObject used")
+ if (graph != null) {
+ graph.synchronized {
+ if (graph.checkpointInProgress) {
+ oos.defaultWriteObject()
+ } else {
+ val msg = "Object of " + this.getClass.getName + " is being serialized " +
+ " possibly as a part of closure of an RDD operation. This is because " +
+ " the DStream object is being referred to from within the closure. " +
+ " Please rewrite the RDD operation inside this DStream to avoid this. " +
+ " This has been enforced to avoid bloating of Spark tasks " +
+ " with unnecessary objects."
+ throw new java.io.NotSerializableException(msg)
+ }
+ }
+ } else {
+ throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.")
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ logDebug(this.getClass().getSimpleName + ".readObject used")
+ ois.defaultReadObject()
+ }
+
+ /**
+ * --------------
+ * DStream operations
+ * --------------
+ */
+ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = {
+ new MappedDStream(this, ssc.sc.clean(mapFunc))
+ }
+
+ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = {
+ new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc))
+ }
+
+ def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc)
+
+ def glom(): DStream[Array[T]] = new GlommedDStream(this)
+
+ def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = {
+ new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc))
+ }
+
+ def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2)
+
+ def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _)
+
+ def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2)
+
+ def foreach(foreachFunc: T => Unit) {
+ val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc))
+ ssc.registerOutputStream(newStream)
+ newStream
+ }
+
+ def foreachRDD(foreachFunc: RDD[T] => Unit) {
+ foreachRDD((r: RDD[T], t: Time) => foreachFunc(r))
+ }
+
+ def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) {
+ val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc))
+ ssc.registerOutputStream(newStream)
+ newStream
+ }
+
+ def transformRDD[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
+ transformRDD((r: RDD[T], t: Time) => transformFunc(r))
+ }
+
+ def transformRDD[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
+ new TransformedDStream(this, ssc.sc.clean(transformFunc))
+ }
+
+ def toBlockingQueue() = {
+ val queue = new ArrayBlockingQueue[RDD[T]](10000)
+ this.foreachRDD(rdd => {
+ queue.add(rdd)
+ })
+ queue
+ }
+
+ def print() {
+ def foreachFunc = (rdd: RDD[T], time: Time) => {
+ val first11 = rdd.take(11)
+ println ("-------------------------------------------")
+ println ("Time: " + time)
+ println ("-------------------------------------------")
+ first11.take(10).foreach(println)
+ if (first11.size > 10) println("...")
+ println()
+ }
+ val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc))
+ ssc.registerOutputStream(newStream)
+ }
+
+ def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime)
+
+ def window(windowTime: Time, slideTime: Time): DStream[T] = {
+ new WindowedDStream(this, windowTime, slideTime)
+ }
+
+ def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime)
+
+ def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = {
+ this.window(windowTime, slideTime).reduce(reduceFunc)
+ }
+
+ def reduceByWindow(
+ reduceFunc: (T, T) => T,
+ invReduceFunc: (T, T) => T,
+ windowTime: Time,
+ slideTime: Time
+ ): DStream[T] = {
+ this.map(x => (1, x))
+ .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1)
+ .map(_._2)
+ }
+
+ def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = {
+ def add(v1: Int, v2: Int) = (v1 + v2)
+ def subtract(v1: Int, v2: Int) = (v1 - v2)
+ this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime)
+ }
+
+ def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that))
+
+ def slice(interval: Interval): Seq[RDD[T]] = {
+ slice(interval.beginTime, interval.endTime)
+ }
+
+ // Get all the RDDs between fromTime and toTime (both inclusive)
+ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = {
+ val rdds = new ArrayBuffer[RDD[T]]()
+ var time = toTime.floor(slideTime)
+ while (time >= zeroTime && time >= fromTime) {
+ getOrCompute(time) match {
+ case Some(rdd) => rdds += rdd
+ case None => //throw new Exception("Could not get RDD for time " + time)
+ }
+ time -= slideTime
+ }
+ rdds.toSeq
+ }
+
+ def register() {
+ ssc.registerOutputStream(this)
+ }
+}
+
+
+abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext)
+ extends DStream[T](ssc_) {
+
+ override def dependencies = List()
+
+ override def slideTime = {
+ if (ssc == null) throw new Exception("ssc is null")
+ if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null")
+ ssc.graph.batchDuration
+ }
+
+ def start()
+
+ def stop()
+}
+
+
+/**
+ * TODO
+ */
+
+class MappedDStream[T: ClassManifest, U: ClassManifest] (
+ parent: DStream[T],
+ mapFunc: T => U
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(_.map[U](mapFunc))
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class FlatMappedDStream[T: ClassManifest, U: ClassManifest](
+ parent: DStream[T],
+ flatMapFunc: T => Traversable[U]
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc))
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class FilteredDStream[T: ClassManifest](
+ parent: DStream[T],
+ filterFunc: T => Boolean
+ ) extends DStream[T](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ parent.getOrCompute(validTime).map(_.filter(filterFunc))
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class MapPartitionedDStream[T: ClassManifest, U: ClassManifest](
+ parent: DStream[T],
+ mapPartFunc: Iterator[T] => Iterator[U]
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc))
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class GlommedDStream[T: ClassManifest](parent: DStream[T])
+ extends DStream[Array[T]](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[Array[T]]] = {
+ parent.getOrCompute(validTime).map(_.glom())
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
+ parent: DStream[(K,V)],
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiner: (C, C) => C,
+ partitioner: Partitioner
+ ) extends DStream [(K,C)] (parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K,C)]] = {
+ parent.getOrCompute(validTime) match {
+ case Some(rdd) =>
+ Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner))
+ case None => None
+ }
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+ parent: DStream[(K, V)],
+ mapValueFunc: V => U
+ ) extends DStream[(K, U)](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+ parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc))
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+ parent: DStream[(K, V)],
+ flatMapValueFunc: V => TraversableOnce[U]
+ ) extends DStream[(K, U)](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+ parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc))
+ }
+}
+
+
+
+/**
+ * TODO
+ */
+
+class UnionDStream[T: ClassManifest](parents: Array[DStream[T]])
+ extends DStream[T](parents.head.ssc) {
+
+ if (parents.length == 0) {
+ throw new IllegalArgumentException("Empty array of parents")
+ }
+
+ if (parents.map(_.ssc).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different StreamingContexts")
+ }
+
+ if (parents.map(_.slideTime).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different slide times")
+ }
+
+ override def dependencies = parents.toList
+
+ override def slideTime: Time = parents.head.slideTime
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ val rdds = new ArrayBuffer[RDD[T]]()
+ parents.map(_.getOrCompute(validTime)).foreach(_ match {
+ case Some(rdd) => rdds += rdd
+ case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime)
+ })
+ if (rdds.size > 0) {
+ Some(new UnionRDD(ssc.sc, rdds))
+ } else {
+ None
+ }
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class PerElementForEachDStream[T: ClassManifest] (
+ parent: DStream[T],
+ foreachFunc: T => Unit
+ ) extends DStream[Unit](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[Unit]] = None
+
+ override def generateJob(time: Time): Option[Job] = {
+ parent.getOrCompute(time) match {
+ case Some(rdd) =>
+ val jobFunc = () => {
+ val sparkJobFunc = {
+ (iterator: Iterator[T]) => iterator.foreach(foreachFunc)
+ }
+ ssc.sc.runJob(rdd, sparkJobFunc)
+ }
+ Some(new Job(time, jobFunc))
+ case None => None
+ }
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class PerRDDForEachDStream[T: ClassManifest] (
+ parent: DStream[T],
+ foreachFunc: (RDD[T], Time) => Unit
+ ) extends DStream[Unit](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[Unit]] = None
+
+ override def generateJob(time: Time): Option[Job] = {
+ parent.getOrCompute(time) match {
+ case Some(rdd) =>
+ val jobFunc = () => {
+ foreachFunc(rdd, time)
+ }
+ Some(new Job(time, jobFunc))
+ case None => None
+ }
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class TransformedDStream[T: ClassManifest, U: ClassManifest] (
+ parent: DStream[T],
+ transformFunc: (RDD[T], Time) => RDD[U]
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(transformFunc(_, validTime))
+ }
+}
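
For orientation, the operations defined in DStream.scala compose the same way as RDD operations. A sketch of user code, assuming a StreamingContext and a DStream[String] named lines obtained from one of the input streams in this commit (the names and durations below are illustrative only):

    import spark.RDD
    import spark.streaming._

    val longLines = lines.filter(_.length > 80)
    // Count the long lines seen in the last 30 seconds, sliding every 10 seconds.
    val counts = longLines.window(Milliseconds(30000), Milliseconds(10000)).count()
    counts.print()   // registers an output stream that prints up to 10 elements per batch
    counts.foreachRDD((rdd: RDD[Int], time: Time) => {
      println("Window ending at " + time + " contained " + rdd.collect().sum + " long lines")
    })
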
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
new file mode 100644
index 0000000000..ac44d7a2a6
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -0,0 +1,124 @@
+package spark.streaming
+
+import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
+import collection.mutable.ArrayBuffer
+import spark.Logging
+
+final class DStreamGraph extends Serializable with Logging {
+ initLogging()
+
+ private val inputStreams = new ArrayBuffer[InputDStream[_]]()
+ private val outputStreams = new ArrayBuffer[DStream[_]]()
+
+ private[streaming] var zeroTime: Time = null
+ private[streaming] var batchDuration: Time = null
+ private[streaming] var rememberDuration: Time = null
+ private[streaming] var checkpointInProgress = false
+
+ def start(time: Time) {
+ this.synchronized {
+ if (zeroTime != null) {
+ throw new Exception("DStream graph computation already started")
+ }
+ zeroTime = time
+ outputStreams.foreach(_.initialize(zeroTime))
+ outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values
+ if (rememberDuration != null) {
+ // if custom rememberDuration has been provided, set the rememberDuration
+ outputStreams.foreach(_.setRememberDuration(rememberDuration))
+ }
+ inputStreams.par.foreach(_.start())
+ }
+ }
+
+ def stop() {
+ this.synchronized {
+ inputStreams.par.foreach(_.stop())
+ }
+ }
+
+ private[streaming] def setContext(ssc: StreamingContext) {
+ this.synchronized {
+ outputStreams.foreach(_.setContext(ssc))
+ }
+ }
+
+ def setBatchDuration(duration: Time) {
+ this.synchronized {
+ if (batchDuration != null) {
+ throw new Exception("Batch duration already set as " + batchDuration +
+ ". cannot set it again.")
+ }
+ }
+ batchDuration = duration
+ }
+
+ def setRememberDuration(duration: Time) {
+ this.synchronized {
+ if (rememberDuration != null) {
+ throw new Exception("Batch duration already set as " + batchDuration +
+ ". cannot set it again.")
+ }
+ }
+ rememberDuration = duration
+ }
+
+ def addInputStream(inputStream: InputDStream[_]) {
+ this.synchronized {
+ inputStream.setGraph(this)
+ inputStreams += inputStream
+ }
+ }
+
+ def addOutputStream(outputStream: DStream[_]) {
+ this.synchronized {
+ outputStream.setGraph(this)
+ outputStreams += outputStream
+ }
+ }
+
+ def getInputStreams() = inputStreams.toArray
+
+ def getOutputStreams() = outputStreams.toArray
+
+ def generateRDDs(time: Time): Seq[Job] = {
+ this.synchronized {
+ outputStreams.flatMap(outputStream => outputStream.generateJob(time))
+ }
+ }
+
+ def forgetOldRDDs(time: Time) {
+ this.synchronized {
+ outputStreams.foreach(_.forgetOldRDDs(time))
+ }
+ }
+
+ def validate() {
+ this.synchronized {
+ assert(batchDuration != null, "Batch duration has not been set")
+ assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low")
+ assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute")
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ this.synchronized {
+ logDebug("DStreamGraph.writeObject used")
+ checkpointInProgress = true
+ oos.defaultWriteObject()
+ checkpointInProgress = false
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ this.synchronized {
+ logDebug("DStreamGraph.readObject used")
+ checkpointInProgress = true
+ ois.defaultReadObject()
+ checkpointInProgress = false
+ }
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
new file mode 100644
index 0000000000..537ec88047
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
@@ -0,0 +1,87 @@
+package spark.streaming
+
+import spark.RDD
+import spark.rdd.UnionRDD
+
+import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import java.io.{ObjectInputStream, IOException}
+
+
+class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest](
+ @transient ssc_ : StreamingContext,
+ directory: String,
+ filter: PathFilter = FileInputDStream.defaultPathFilter,
+ newFilesOnly: Boolean = true)
+ extends InputDStream[(K, V)](ssc_) {
+
+ @transient private var path_ : Path = null
+ @transient private var fs_ : FileSystem = null
+
+ var lastModTime: Long = 0
+
+ def path(): Path = {
+ if (path_ == null) path_ = new Path(directory)
+ path_
+ }
+
+ def fs(): FileSystem = {
+ if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
+ fs_
+ }
+
+ override def start() {
+ if (newFilesOnly) {
+ lastModTime = System.currentTimeMillis()
+ } else {
+ lastModTime = 0
+ }
+ }
+
+ override def stop() { }
+
+ override def compute(validTime: Time): Option[RDD[(K, V)]] = {
+ val newFilter = new PathFilter() {
+ var latestModTime = 0L
+
+ def accept(path: Path): Boolean = {
+ if (!filter.accept(path)) {
+ return false
+ } else {
+ val modTime = fs.getFileStatus(path).getModificationTime()
+ if (modTime <= lastModTime) {
+ return false
+ }
+ if (modTime > latestModTime) {
+ latestModTime = modTime
+ }
+ return true
+ }
+ }
+ }
+
+ val newFiles = fs.listStatus(path, newFilter)
+ logInfo("New files: " + newFiles.map(_.getPath).mkString(", "))
+ if (newFiles.length > 0) {
+ lastModTime = newFilter.latestModTime
+ }
+ val newRDD = new UnionRDD(ssc.sc, newFiles.map(
+ file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)))
+ Some(newRDD)
+ }
+}
+
+object FileInputDStream {
+ val defaultPathFilter = new PathFilter with Serializable {
+ def accept(path: Path): Boolean = {
+ val file = path.getName()
+ if (file.startsWith(".") || file.endsWith("_tmp")) {
+ return false
+ } else {
+ return true
+ }
+ }
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala
new file mode 100644
index 0000000000..ffb7725ac9
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Interval.scala
@@ -0,0 +1,50 @@
+package spark.streaming
+
+case class Interval(beginTime: Time, endTime: Time) {
+ def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs))
+
+ def duration(): Time = endTime - beginTime
+
+ def + (time: Time): Interval = {
+ new Interval(beginTime + time, endTime + time)
+ }
+
+ def - (time: Time): Interval = {
+ new Interval(beginTime - time, endTime - time)
+ }
+
+ def < (that: Interval): Boolean = {
+ if (this.duration != that.duration) {
+ throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]")
+ }
+ this.endTime < that.endTime
+ }
+
+ def <= (that: Interval) = (this < that || this == that)
+
+ def > (that: Interval) = !(this <= that)
+
+ def >= (that: Interval) = !(this < that)
+
+ def next(): Interval = {
+ this + (endTime - beginTime)
+ }
+
+ def isZero = (beginTime.isZero && endTime.isZero)
+
+ def toFormattedString = beginTime.toFormattedString + "-" + endTime.toFormattedString
+
+ override def toString = "[" + beginTime + ", " + endTime + "]"
+}
+
+object Interval {
+ def zero() = new Interval (Time.zero, Time.zero)
+
+ def currentInterval(intervalDuration: Time): Interval = {
+ val time = Time(System.currentTimeMillis)
+ val intervalBegin = time.floor(intervalDuration)
+ Interval(intervalBegin, intervalBegin + intervalDuration)
+ }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala
new file mode 100644
index 0000000000..0bcb6fd8dc
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Job.scala
@@ -0,0 +1,22 @@
+package spark.streaming
+
+import java.util.concurrent.atomic.AtomicLong
+
+class Job(val time: Time, func: () => _) {
+ val id = Job.getNewId()
+ def run(): Long = {
+ val startTime = System.currentTimeMillis
+ func()
+ val stopTime = System.currentTimeMillis
+ (stopTime - startTime)
+ }
+
+ override def toString = "streaming job " + id + " @ " + time
+}
+
+object Job {
+ val id = new AtomicLong(0)
+
+ def getNewId() = id.getAndIncrement()
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala
new file mode 100644
index 0000000000..9bf9251519
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/JobManager.scala
@@ -0,0 +1,32 @@
+package spark.streaming
+
+import spark.Logging
+import spark.SparkEnv
+import java.util.concurrent.Executors
+
+
+class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
+
+ class JobHandler(ssc: StreamingContext, job: Job) extends Runnable {
+ def run() {
+ SparkEnv.set(ssc.env)
+ try {
+ val timeTaken = job.run()
+ logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format(
+ (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0))
+ } catch {
+ case e: Exception =>
+ logError("Running " + job + " failed", e)
+ }
+ }
+ }
+
+ initLogging()
+
+ val jobExecutor = Executors.newFixedThreadPool(numThreads)
+
+ def runJob(job: Job) {
+ jobExecutor.execute(new JobHandler(ssc, job))
+ logInfo("Added " + job + " to queue")
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
new file mode 100644
index 0000000000..f3f4c3ab13
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
@@ -0,0 +1,151 @@
+package spark.streaming
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark.{Logging, SparkEnv, RDD}
+import spark.rdd.BlockRDD
+import spark.storage.StorageLevel
+
+import java.nio.ByteBuffer
+
+import akka.actor.{Props, Actor}
+import akka.pattern.ask
+import akka.dispatch.Await
+import akka.util.duration._
+
+abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext)
+ extends InputDStream[T](ssc_) {
+
+ // This is a unique identifier that is used to match the network receiver with the
+ // corresponding network input stream.
+ val id = ssc.getNewNetworkStreamId()
+
+ /**
+ * This method creates the receiver object that will be sent to the workers
+ * to receive data. This method needs to be defined by any specific implementation
+ * of a NetworkInputDStream.
+ */
+ def createReceiver(): NetworkReceiver[T]
+
+ // Nothing to start or stop as both are taken care of by the NetworkInputTracker.
+ def start() {}
+
+ def stop() {}
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
+ Some(new BlockRDD[T](ssc.sc, blockIds))
+ }
+}
+
+
+sealed trait NetworkReceiverMessage
+case class StopReceiver(msg: String) extends NetworkReceiverMessage
+case class ReportBlock(blockId: String) extends NetworkReceiverMessage
+case class ReportError(msg: String) extends NetworkReceiverMessage
+
+abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging {
+
+ initLogging()
+
+ lazy protected val env = SparkEnv.get
+
+ lazy protected val actor = env.actorSystem.actorOf(
+ Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId)
+
+ lazy protected val receivingThread = Thread.currentThread()
+
+ /** This method will be called to start receiving data. */
+ protected def onStart()
+
+ /** This method will be called to stop receiving data. */
+ protected def onStop()
+
+ /**
+ * This method starts the receiver. First it accesses all the lazy members to
+ * materialize them. Then it calls the user-defined onStart() method to start
+ * other threads, etc. required to receive the data.
+ */
+ def start() {
+ try {
+ // Access the lazy vals to materialize them
+ env
+ actor
+ receivingThread
+
+ // Call user-defined onStart()
+ onStart()
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Receiving thread interrupted")
+ //println("Receiving thread interrupted")
+ case e: Exception =>
+ stopOnError(e)
+ }
+ }
+
+ /**
+ * This method stops the receiver. First it interrupts the main receiving thread,
+ * that is, the thread that called receiver.start(). Then it calls the user-defined
+ * onStop() method to stop other threads and/or do cleanup.
+ */
+ def stop() {
+ receivingThread.interrupt()
+ onStop()
+ //TODO: terminate the actor
+ }
+
+ /**
+ * This method stops the receiver and reports the exception to the tracker.
+ * This should be called whenever an exception has happened on any thread
+ * of the receiver.
+ */
+ protected def stopOnError(e: Exception) {
+ logError("Error receiving data", e)
+ stop()
+ actor ! ReportError(e.toString)
+ }
+
+ /**
+ * This method pushes a block (as iterator of values) into the block manager.
+ */
+ protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) {
+ val buffer = new ArrayBuffer[T] ++ iterator
+ env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level)
+ actor ! ReportBlock(blockId)
+ }
+
+ /**
+ * This method pushes a block (as bytes) into the block manager.
+ */
+ protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ env.blockManager.putBytes(blockId, bytes, level)
+ actor ! ReportBlock(blockId)
+ }
+
+ /** A helper actor that communicates with the NetworkInputTracker */
+ private class NetworkReceiverActor extends Actor {
+ logInfo("Attempting to register with tracker")
+ val ip = System.getProperty("spark.master.host", "localhost")
+ val port = System.getProperty("spark.master.port", "7077").toInt
+ val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
+ val tracker = env.actorSystem.actorFor(url)
+ val timeout = 5.seconds
+
+ override def preStart() {
+ val future = tracker.ask(RegisterReceiver(streamId, self))(timeout)
+ Await.result(future, timeout)
+ }
+
+ override def receive() = {
+ case ReportBlock(blockId) =>
+ tracker ! AddBlocks(streamId, Array(blockId))
+ case ReportError(msg) =>
+ tracker ! DeregisterReceiver(streamId, msg)
+ case StopReceiver(msg) =>
+ stop()
+ tracker ! DeregisterReceiver(streamId, msg)
+ }
+ }
+}
+
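
To make the NetworkReceiver contract concrete, here is a minimal hypothetical subclass (not part of this commit). onStart() must return promptly, so the work happens on a separate thread; data is handed to the block manager via pushBlock. The class name, block id format, and data are arbitrary, and the shutdown flag is a simplification of the interrupt-based stop() above.

    import spark.storage.StorageLevel
    import spark.streaming.NetworkReceiver

    // Hypothetical receiver that pushes a small block of strings once per second.
    class DummyReceiver(streamId: Int) extends NetworkReceiver[String](streamId) {

      @volatile private var stopped = false

      protected def onStart() {
        // Produce data on a separate thread so that start() returns promptly.
        val thread = new Thread("DummyReceiver-" + streamId) {
          override def run() { receive() }
        }
        thread.start()
      }

      protected def onStop() {
        stopped = true
      }

      private def receive() {
        var blockNum = 0
        while (!stopped) {
          val lines = (1 to 100).iterator.map(i => "line " + i)
          pushBlock("input-" + streamId + "-" + blockNum, lines, StorageLevel.MEMORY_ONLY)
          blockNum += 1
          Thread.sleep(1000)
        }
      }
    }
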
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
new file mode 100644
index 0000000000..07ef79415d
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -0,0 +1,118 @@
+package spark.streaming
+
+import spark.Logging
+import spark.SparkEnv
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Queue
+
+import akka.actor._
+import akka.pattern.ask
+import akka.util.duration._
+import akka.dispatch._
+
+trait NetworkInputTrackerMessage
+case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
+case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage
+case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
+
+
+class NetworkInputTracker(
+ @transient ssc: StreamingContext,
+ @transient networkInputStreams: Array[NetworkInputDStream[_]])
+ extends Logging {
+
+ val networkInputStreamIds = networkInputStreams.map(_.id).toArray
+ val receiverExecutor = new ReceiverExecutor()
+ val receiverInfo = new HashMap[Int, ActorRef]
+ val receivedBlockIds = new HashMap[Int, Queue[String]]
+ val timeout = 5000.milliseconds
+
+ var currentTime: Time = null
+
+ def start() {
+ ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
+ receiverExecutor.start()
+ }
+
+ def stop() {
+ receiverExecutor.interrupt()
+ receiverExecutor.stopReceivers()
+ }
+
+ def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized {
+ val queue = receivedBlockIds.synchronized {
+ receivedBlockIds.getOrElse(receiverId, new Queue[String]())
+ }
+ val result = queue.synchronized {
+ queue.dequeueAll(x => true)
+ }
+ result.toArray
+ }
+
+ private class NetworkInputTrackerActor extends Actor {
+ def receive = {
+ case RegisterReceiver(streamId, receiverActor) => {
+ if (!networkInputStreamIds.contains(streamId)) {
+ throw new Exception("Register received for unexpected id " + streamId)
+ }
+ receiverInfo += ((streamId, receiverActor))
+ logInfo("Registered receiver for network stream " + streamId)
+ sender ! true
+ }
+ case AddBlocks(streamId, blockIds) => {
+ val tmp = receivedBlockIds.synchronized {
+ if (!receivedBlockIds.contains(streamId)) {
+ receivedBlockIds += ((streamId, new Queue[String]))
+ }
+ receivedBlockIds(streamId)
+ }
+ tmp.synchronized {
+ tmp ++= blockIds
+ }
+ }
+ case DeregisterReceiver(streamId, msg) => {
+ receiverInfo -= streamId
+ logInfo("De-registered receiver for network stream " + streamId
+ + " with message " + msg)
+ //TODO: Do something about the corresponding NetworkInputDStream
+ }
+ }
+ }
+
+ class ReceiverExecutor extends Thread {
+ val env = ssc.env
+
+ override def run() {
+ try {
+ SparkEnv.set(env)
+ startReceivers()
+ } catch {
+ case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
+ } finally {
+ stopReceivers()
+ }
+ }
+
+ def startReceivers() {
+ val receivers = networkInputStreams.map(_.createReceiver())
+ val tempRDD = ssc.sc.makeRDD(receivers, receivers.size)
+
+ val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => {
+ if (!iterator.hasNext) {
+ throw new Exception("Could not start receiver as details not found.")
+ }
+ iterator.next().start()
+ }
+ ssc.sc.runJob(tempRDD, startReceiver)
+ }
+
+ def stopReceivers() {
+ //implicit val ec = env.actorSystem.dispatcher
+ receiverInfo.values.foreach(_ ! StopReceiver)
+ //val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList
+ //val futureOfList = Future.sequence(listOfFutures)
+ //Await.result(futureOfList, timeout)
+ }
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
new file mode 100644
index 0000000000..ce1f4ad0a0
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
@@ -0,0 +1,236 @@
+package spark.streaming
+
+import scala.collection.mutable.ArrayBuffer
+import spark.{Manifests, RDD, Partitioner, HashPartitioner}
+import spark.streaming.StreamingContext._
+import javax.annotation.Nullable
+
+class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)])
+extends Serializable {
+
+ def ssc = self.ssc
+
+ def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = {
+ new HashPartitioner(numPartitions)
+ }
+
+ /* ---------------------------------- */
+ /* DStream operations for key-value pairs */
+ /* ---------------------------------- */
+
+ def groupByKey(): DStream[(K, Seq[V])] = {
+ groupByKey(defaultPartitioner())
+ }
+
+ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = {
+ groupByKey(defaultPartitioner(numPartitions))
+ }
+
+ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = {
+ val createCombiner = (v: V) => ArrayBuffer[V](v)
+ val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v)
+ val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2)
+ combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner).asInstanceOf[DStream[(K, Seq[V])]]
+ }
+
+ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = {
+ reduceByKey(reduceFunc, defaultPartitioner())
+ }
+
+ def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = {
+ reduceByKey(reduceFunc, defaultPartitioner(numPartitions))
+ }
+
+ def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = {
+ val cleanedReduceFunc = ssc.sc.clean(reduceFunc)
+ combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner)
+ }
+
+ private def combineByKey[C: ClassManifest](
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiner: (C, C) => C,
+ partitioner: Partitioner) : ShuffledDStream[K, V, C] = {
+ new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner)
+ }
+
+ def groupByKeyAndWindow(windowTime: Time, slideTime: Time): DStream[(K, Seq[V])] = {
+ groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner())
+ }
+
+ def groupByKeyAndWindow(
+ windowTime: Time,
+ slideTime: Time,
+ numPartitions: Int
+ ): DStream[(K, Seq[V])] = {
+ groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner(numPartitions))
+ }
+
+ def groupByKeyAndWindow(
+ windowTime: Time,
+ slideTime: Time,
+ partitioner: Partitioner
+ ): DStream[(K, Seq[V])] = {
+ self.window(windowTime, slideTime).groupByKey(partitioner)
+ }
+
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ windowTime: Time
+ ): DStream[(K, V)] = {
+ reduceByKeyAndWindow(reduceFunc, windowTime, self.slideTime, defaultPartitioner())
+ }
+
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ windowTime: Time,
+ slideTime: Time
+ ): DStream[(K, V)] = {
+ reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner())
+ }
+
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ windowTime: Time,
+ slideTime: Time,
+ numPartitions: Int
+ ): DStream[(K, V)] = {
+ reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions))
+ }
+
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ windowTime: Time,
+ slideTime: Time,
+ partitioner: Partitioner
+ ): DStream[(K, V)] = {
+ self.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner)
+ }
+
+ // This method is the efficient sliding window reduce operation,
+ // which requires the specification of an inverse reduce function,
+ // so that new elements introduced in the window can be "added" using
+ // reduceFunc to the previous window's result and old elements can be
+ // "subtracted using invReduceFunc.
+
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ windowTime: Time,
+ slideTime: Time
+ ): DStream[(K, V)] = {
+
+ reduceByKeyAndWindow(
+ reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner())
+ }
+
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ windowTime: Time,
+ slideTime: Time,
+ numPartitions: Int
+ ): DStream[(K, V)] = {
+
+ reduceByKeyAndWindow(
+ reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions))
+ }
+
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ windowTime: Time,
+ slideTime: Time,
+ partitioner: Partitioner
+ ): DStream[(K, V)] = {
+
+ val cleanedReduceFunc = ssc.sc.clean(reduceFunc)
+ val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc)
+ new ReducedWindowedDStream[K, V](
+ self, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner)
+ }
+
+ // TODO:
+ //
+ //
+ //
+ //
+ def updateStateByKey[S <: AnyRef : ClassManifest](
+ updateFunc: (Seq[V], Option[S]) => Option[S]
+ ): DStream[(K, S)] = {
+ updateStateByKey(updateFunc, defaultPartitioner())
+ }
+
+ def updateStateByKey[S <: AnyRef : ClassManifest](
+ updateFunc: (Seq[V], Option[S]) => Option[S],
+ numPartitions: Int
+ ): DStream[(K, S)] = {
+ updateStateByKey(updateFunc, defaultPartitioner(numPartitions))
+ }
+
+ def updateStateByKey[S <: AnyRef : ClassManifest](
+ updateFunc: (Seq[V], Option[S]) => Option[S],
+ partitioner: Partitioner
+ ): DStream[(K, S)] = {
+ val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
+ iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+ }
+ updateStateByKey(newUpdateFunc, partitioner, true)
+ }
+
+ def updateStateByKey[S <: AnyRef : ClassManifest](
+ updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+ partitioner: Partitioner,
+ rememberPartitioner: Boolean
+ ): DStream[(K, S)] = {
+ new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+ }
+
+
+ def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = {
+ new MapValuesDStream[K, V, U](self, mapValuesFunc)
+ }
+
+ def flatMapValues[U: ClassManifest](
+ flatMapValuesFunc: V => TraversableOnce[U]
+ ): DStream[(K, U)] = {
+ new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc)
+ }
+
+ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, defaultPartitioner())
+ }
+
+ def cogroup[W: ClassManifest](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (Seq[V], Seq[W]))] = {
+
+ val cgd = new CoGroupedDStream[K](
+ Seq(self.asInstanceOf[DStream[(_, _)]], other.asInstanceOf[DStream[(_, _)]]),
+ partitioner
+ )
+ val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)(
+ classManifest[K],
+ Manifests.seqSeqManifest
+ )
+ pdfs.mapValues {
+ case Seq(vs, ws) =>
+ (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
+ }
+ }
+
+ def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = {
+ join[W](other, defaultPartitioner())
+ }
+
+ def join[W: ClassManifest](other: DStream[(K, W)], partitioner: Partitioner): DStream[(K, (V, W))] = {
+ this.cogroup(other, partitioner)
+ .flatMapValues{
+ case (vs, ws) =>
+ for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
+ }
+ }
+}
+
+
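As a rough usage sketch of the incremental window reduce above (the master, application name, host, port and durations below are placeholder assumptions, not part of the patch):

import spark.streaming._
import spark.streaming.StreamingContext._   // enables the pair operations on DStream[(K, V)]

val ssc = new StreamingContext("local[2]", "WindowedCounts")
ssc.setBatchDuration(Seconds(2))

def add(a: Int, b: Int) = a + b
def subtract(a: Int, b: Int) = a - b

val pairs = ssc.networkTextStream("localhost", 9999)
  .flatMap(_.split(" "))
  .map(word => (word, 1))

// 30 second window sliding every 2 seconds: each interval only folds in the batches
// entering the window with `add` and removes those leaving it with `subtract`.
val windowedCounts = pairs.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2))
windowedCounts.print()
ssc.start()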
diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala
new file mode 100644
index 0000000000..bb86e51932
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala
@@ -0,0 +1,40 @@
+package spark.streaming
+
+import spark.RDD
+import spark.rdd.UnionRDD
+
+import scala.collection.mutable.Queue
+import scala.collection.mutable.ArrayBuffer
+
+class QueueInputDStream[T: ClassManifest](
+ @transient ssc: StreamingContext,
+ val queue: Queue[RDD[T]],
+ oneAtATime: Boolean,
+ defaultRDD: RDD[T]
+ ) extends InputDStream[T](ssc) {
+
+ override def start() { }
+
+ override def stop() { }
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ val buffer = new ArrayBuffer[RDD[T]]()
+ if (oneAtATime && queue.size > 0) {
+ buffer += queue.dequeue()
+ } else {
+ buffer ++= queue
+ }
+ if (buffer.size > 0) {
+ if (oneAtATime) {
+ Some(buffer.head)
+ } else {
+ Some(new UnionRDD(ssc.sc, buffer.toSeq))
+ }
+ } else if (defaultRDD != null) {
+ Some(defaultRDD)
+ } else {
+ None
+ }
+ }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
new file mode 100644
index 0000000000..e022b85fbe
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
@@ -0,0 +1,83 @@
+package spark.streaming
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import java.nio.channels.{ReadableByteChannel, SocketChannel}
+import java.io.EOFException
+import java.util.concurrent.ArrayBlockingQueue
+import spark._
+import spark.storage.StorageLevel
+
+/**
+ * An input stream that reads blocks of serialized objects from a given network address.
+ * The blocks will be inserted directly into the block store. This is the fastest way to get
+ * data into Spark Streaming, though it requires the sender to batch data and serialize it
+ * in the format that the system is configured with.
+ */
+class RawInputDStream[T: ClassManifest](
+ @transient ssc_ : StreamingContext,
+ host: String,
+ port: Int,
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[T](ssc_ ) with Logging {
+
+ def createReceiver(): NetworkReceiver[T] = {
+ new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]]
+ }
+}
+
+class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel)
+ extends NetworkReceiver[Any](streamId) {
+
+ var blockPushingThread: Thread = null
+
+ def onStart() {
+ // Open a socket to the target address and keep reading from it
+ logInfo("Connecting to " + host + ":" + port)
+ val channel = SocketChannel.open()
+ channel.configureBlocking(true)
+ channel.connect(new InetSocketAddress(host, port))
+ logInfo("Connected to " + host + ":" + port)
+
+ val queue = new ArrayBlockingQueue[ByteBuffer](2)
+
+ blockPushingThread = new DaemonThread {
+ override def run() {
+ var nextBlockNumber = 0
+ while (true) {
+ val buffer = queue.take()
+ val blockId = "input-" + streamId + "-" + nextBlockNumber
+ nextBlockNumber += 1
+ pushBlock(blockId, buffer, storageLevel)
+ }
+ }
+ }
+ blockPushingThread.start()
+
+ val lengthBuffer = ByteBuffer.allocate(4)
+ while (true) {
+ lengthBuffer.clear()
+ readFully(channel, lengthBuffer)
+ lengthBuffer.flip()
+ val length = lengthBuffer.getInt()
+ val dataBuffer = ByteBuffer.allocate(length)
+ readFully(channel, dataBuffer)
+ dataBuffer.flip()
+ logInfo("Read a block with " + length + " bytes")
+ queue.put(dataBuffer)
+ }
+ }
+
+ def onStop() {
+ blockPushingThread.interrupt()
+ }
+
+ /** Read a buffer fully from a given Channel */
+ private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) {
+ while (dest.position < dest.limit) {
+ if (channel.read(dest) == -1) {
+ throw new EOFException("End of channel")
+ }
+ }
+ }
+}
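A rough sketch of consuming such a raw stream, assuming an external sender is already pushing length-prefixed, serialized blocks of strings to a placeholder host and port (master and application name are also placeholders):

import spark.streaming._
import spark.storage.StorageLevel

val ssc = new StreamingContext("local[2]", "RawCount")
ssc.setBatchDuration(Seconds(1))

// Blocks arrive pre-serialized and go straight into the block manager; here we just
// count the deserialized strings in each batch and print the counts on the driver.
val lines = ssc.rawNetworkStream[String]("localhost", 9999, StorageLevel.MEMORY_ONLY_2)
lines.count().foreachRDD(rdd => println("Strings in batch: " + rdd.collect().mkString))
ssc.start()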
diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
new file mode 100644
index 0000000000..1c57d5f855
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
@@ -0,0 +1,143 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+
+import spark.RDD
+import spark.rdd.UnionRDD
+import spark.rdd.CoGroupedRDD
+import spark.Partitioner
+import spark.SparkContext._
+import spark.storage.StorageLevel
+
+import scala.collection.mutable.ArrayBuffer
+import collection.SeqProxy
+
+class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
+ parent: DStream[(K, V)],
+ reduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ _windowTime: Time,
+ _slideTime: Time,
+ partitioner: Partitioner
+ ) extends DStream[(K,V)](parent.ssc) {
+
+ if (!_windowTime.isMultipleOf(parent.slideTime))
+ throw new Exception("The window duration of ReducedWindowedDStream (" + _windowTime + ") " +
+ "must be a multiple of the slide duration of the parent DStream (" + parent.slideTime + ")")
+
+ if (!_slideTime.isMultipleOf(parent.slideTime))
+ throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " +
+ "must be a multiple of the slide duration of the parent DStream (" + parent.slideTime + ")")
+
+ @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
+
+ def windowTime: Time = _windowTime
+
+ override def dependencies = List(reducedStream)
+
+ override def slideTime: Time = _slideTime
+
+ //TODO: This is wrong. This should depend on the checkpointInterval
+ override def parentRememberDuration: Time = rememberDuration + windowTime
+
+ override def persist(
+ storageLevel: StorageLevel,
+ checkpointLevel: StorageLevel,
+ checkpointInterval: Time): DStream[(K,V)] = {
+ super.persist(storageLevel, checkpointLevel, checkpointInterval)
+ reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval)
+ this
+ }
+
+ protected[streaming] override def setRememberDuration(time: Time) {
+ if (rememberDuration == null || rememberDuration < time) {
+ rememberDuration = time
+ dependencies.foreach(_.setRememberDuration(rememberDuration + windowTime))
+ }
+ }
+
+ override def compute(validTime: Time): Option[RDD[(K, V)]] = {
+ val reduceF = reduceFunc
+ val invReduceF = invReduceFunc
+
+ val currentTime = validTime
+ val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime)
+ val previousWindow = currentWindow - slideTime
+
+ logDebug("Window time = " + windowTime)
+ logDebug("Slide time = " + slideTime)
+ logDebug("ZeroTime = " + zeroTime)
+ logDebug("Current window = " + currentWindow)
+ logDebug("Previous window = " + previousWindow)
+
+ // _____________________________
+ // | previous window _________|___________________
+ // |___________________| current window | --------------> Time
+ // |_____________________________|
+ //
+ // |________ _________| |________ _________|
+ // | |
+ // V V
+ // old RDDs new RDDs
+ //
+
+ // Get the RDDs of the reduced values in "old time steps"
+ val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime)
+ logDebug("# old RDDs = " + oldRDDs.size)
+
+ // Get the RDDs of the reduced values in "new time steps"
+ val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime)
+ logDebug("# new RDDs = " + newRDDs.size)
+
+ // Get the RDD of the reduced value of the previous window
+ val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]()))
+
+ // Make the list of RDDs that need to be cogrouped together to merge their reduced values
+ val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs
+
+ // Cogroup the reduced RDDs and merge the reduced values
+ val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner)
+ //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _
+
+ val numOldValues = oldRDDs.size
+ val numNewValues = newRDDs.size
+
+ val mergeValues = (seqOfValues: Seq[Seq[V]]) => {
+ if (seqOfValues.size != 1 + numOldValues + numNewValues) {
+ throw new Exception("Unexpected number of sequences of reduced values")
+ }
+ // Get the reduced values of the "old time steps" that will be removed from the current window
+ val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head)
+ // Get the reduced values of the "new time steps" that have entered the current window
+ val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
+ if (seqOfValues(0).isEmpty) {
+ // If the previous window's reduced value does not exist, then at least new values should exist
+ if (newValues.isEmpty) {
+ throw new Exception("Neither previous window has value for key, nor new values found")
+ }
+ // Reduce the new values
+ newValues.reduce(reduceF) // return
+ } else {
+ // Get the previous window's reduced value
+ var tempValue = seqOfValues(0).head
+ // If old values exist, then inverse-reduce them from the previous value
+ if (!oldValues.isEmpty) {
+ tempValue = invReduceF(tempValue, oldValues.reduce(reduceF))
+ }
+ // If new values exist, then reduce them with the previous value
+ if (!newValues.isEmpty) {
+ tempValue = reduceF(tempValue, newValues.reduce(reduceF))
+ }
+ tempValue // return
+ }
+ }
+
+ val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues)
+
+ Some(mergedValuesRDD)
+ }
+
+
+}
+
+
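A small worked example of this merge, assuming a 1 second parent batch, a 4 second window, a 2 second slide, addition as the reduce function and subtraction as its inverse, for a key that contributes a reduced value of 1 in every batch: at time t = 10 the current window covers batches 7-10 and the previous window covered batches 5-8, so the previous window's value is 4. The old batches 5 and 6 are inverse-reduced out (4 - 2 = 2) and the new batches 9 and 10 are reduced in (2 + 2 = 4), giving the correct value of 4 for the current window without re-reducing batches 7 and 8.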
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
new file mode 100644
index 0000000000..7d52e2eddf
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -0,0 +1,69 @@
+package spark.streaming
+
+import util.{ManualClock, RecurringTimer, Clock}
+import spark.SparkEnv
+import spark.Logging
+
+import scala.collection.mutable.HashMap
+
+
+sealed trait SchedulerMessage
+case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage
+
+class Scheduler(ssc: StreamingContext)
+extends Logging {
+
+ initLogging()
+
+ val graph = ssc.graph
+ val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt
+ val jobManager = new JobManager(ssc, concurrentJobs)
+ val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock")
+ val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
+ val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_))
+
+ def start() {
+ // If the context was started from a checkpoint, restart the timer so that its
+ // triggers occur at the same times as the original timer's. Otherwise, start the
+ // timer from scratch and initialize the graph based on the timer's first
+ // trigger time.
+ if (ssc.isCheckpointPresent) {
+ // If manual clock is being used for testing, then
+ // set manual clock to the last checkpointed time
+ if (clock.isInstanceOf[ManualClock]) {
+ val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds
+ clock.asInstanceOf[ManualClock].setTime(lastTime)
+ }
+ timer.restart(graph.zeroTime.milliseconds)
+ logInfo("Scheduler's timer restarted")
+ } else {
+ val firstTime = Time(timer.start())
+ graph.start(firstTime - ssc.graph.batchDuration)
+ logInfo("Scheduler's timer started")
+ }
+ logInfo("Scheduler started")
+ }
+
+ def stop() {
+ timer.stop()
+ graph.stop()
+ logInfo("Scheduler stopped")
+ }
+
+ def generateRDDs(time: Time) {
+ SparkEnv.set(ssc.env)
+ logInfo("\n-----------------------------------------------------\n")
+ graph.generateRDDs(time).foreach(submitJob)
+ logInfo("Generated RDDs for time " + time)
+ graph.forgetOldRDDs(time)
+ if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) {
+ ssc.doCheckpoint(time)
+ logInfo("Checkpointed at time " + time)
+ }
+ }
+
+ def submitJob(job: Job) {
+ jobManager.runJob(job)
+ }
+}
+
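A sketch of how a test might drive this scheduler deterministically through the "spark.streaming.clock" property; the master and application name are placeholders, and touching `ssc.scheduler` assumes the calling code lives in the spark.streaming package, since that field is package-private:

import spark.streaming._

// Must be set before the StreamingContext is started, since the Scheduler reads it on construction.
System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")

val ssc = new StreamingContext("local", "SchedulerClockTest")
ssc.setBatchDuration(Milliseconds(1000))
// ... register input and output streams here ...
ssc.start()

// Advancing the manual clock fires the recurring timer, so generateRDDs() runs
// for three batch intervals without any real waiting.
val clock = ssc.scheduler.clock.asInstanceOf[spark.streaming.util.ManualClock]
clock.addToTime(3 * 1000)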
diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
new file mode 100644
index 0000000000..b566200273
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
@@ -0,0 +1,173 @@
+package spark.streaming
+
+import spark.streaming.util.{RecurringTimer, SystemClock}
+import spark.storage.StorageLevel
+
+import java.io._
+import java.net.Socket
+import java.util.concurrent.ArrayBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+import scala.Serializable
+
+class SocketInputDStream[T: ClassManifest](
+ @transient ssc_ : StreamingContext,
+ host: String,
+ port: Int,
+ bytesToObjects: InputStream => Iterator[T],
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[T](ssc_) {
+
+ def createReceiver(): NetworkReceiver[T] = {
+ new SocketReceiver(id, host, port, bytesToObjects, storageLevel)
+ }
+}
+
+
+class SocketReceiver[T: ClassManifest](
+ streamId: Int,
+ host: String,
+ port: Int,
+ bytesToObjects: InputStream => Iterator[T],
+ storageLevel: StorageLevel
+ ) extends NetworkReceiver[T](streamId) {
+
+ lazy protected val dataHandler = new DataHandler(this)
+
+ protected def onStart() {
+ logInfo("Connecting to " + host + ":" + port)
+ val socket = new Socket(host, port)
+ logInfo("Connected to " + host + ":" + port)
+ dataHandler.start()
+ val iterator = bytesToObjects(socket.getInputStream())
+ while(iterator.hasNext) {
+ val obj = iterator.next
+ dataHandler += obj
+ }
+ }
+
+ protected def onStop() {
+ dataHandler.stop()
+ }
+
+ /**
+ * This is a helper object that manages the data received from the socket. It divides
+ * the objects received into small batches of a few hundred milliseconds each, packages
+ * the batches as blocks, pushes them into the block manager, and reports the block IDs
+ * to the network input tracker. It starts two threads: one to periodically start a new
+ * batch and package the previous batch as a block, and the other to push the blocks
+ * into the block manager.
+ */
+ class DataHandler(receiver: NetworkReceiver[T]) extends Serializable {
+ case class Block(id: String, iterator: Iterator[T])
+
+ val clock = new SystemClock()
+ val blockInterval = 200L
+ val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+ val blockStorageLevel = storageLevel
+ val blocksForPushing = new ArrayBlockingQueue[Block](1000)
+ val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+ var currentBuffer = new ArrayBuffer[T]
+
+ def start() {
+ blockIntervalTimer.start()
+ blockPushingThread.start()
+ logInfo("Data handler started")
+ }
+
+ def stop() {
+ blockIntervalTimer.stop()
+ blockPushingThread.interrupt()
+ logInfo("Data handler stopped")
+ }
+
+ def += (obj: T) {
+ currentBuffer += obj
+ }
+
+ def updateCurrentBuffer(time: Long) {
+ try {
+ val newBlockBuffer = currentBuffer
+ currentBuffer = new ArrayBuffer[T]
+ if (newBlockBuffer.size > 0) {
+ val blockId = "input-" + streamId + "- " + (time - blockInterval)
+ val newBlock = new Block(blockId, newBlockBuffer.toIterator)
+ blocksForPushing.add(newBlock)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block interval timer thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+
+ def keepPushingBlocks() {
+ logInfo("Block pushing thread started")
+ try {
+ while(true) {
+ val block = blocksForPushing.take()
+ pushBlock(block.id, block.iterator, storageLevel)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block pushing thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+ }
+}
+
+
+object SocketReceiver {
+
+ /**
+ * This method translates the data from an InputStream (say, from a socket)
+ * into '\n'-delimited strings and returns an iterator to access the strings.
+ */
+ def bytesToLines(inputStream: InputStream): Iterator[String] = {
+ val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))
+
+ val iterator = new Iterator[String] {
+ var gotNext = false
+ var finished = false
+ var nextValue: String = null
+
+ private def getNext() {
+ try {
+ nextValue = dataInputStream.readLine()
+ if (nextValue == null) {
+ finished = true
+ }
+ }
+ gotNext = true
+ }
+
+ override def hasNext: Boolean = {
+ if (!finished) {
+ if (!gotNext) {
+ getNext()
+ if (finished) {
+ dataInputStream.close()
+ }
+ }
+ }
+ !finished
+ }
+
+ override def next(): String = {
+ if (finished) {
+ throw new NoSuchElementException("End of stream")
+ }
+ if (!gotNext) {
+ getNext()
+ }
+ gotNext = false
+ nextValue
+ }
+ }
+ iterator
+ }
+}
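A sketch of plugging a custom converter into networkStream, reusing bytesToLines above; it assumes each line on the wire is a comma-separated list of integers, with master, application name, host and port as placeholders:

import spark.streaming._
import spark.storage.StorageLevel
import java.io.InputStream

val ssc = new StreamingContext("local[2]", "CsvInts")
ssc.setBatchDuration(Seconds(2))

// Reuse bytesToLines to frame records, then parse each line into an Array[Int].
def csvIntLines(in: InputStream): Iterator[Array[Int]] =
  SocketReceiver.bytesToLines(in).map(_.split(",").map(_.trim.toInt))

val records = ssc.networkStream[Array[Int]](
  "localhost", 9999, csvIntLines, StorageLevel.MEMORY_AND_DISK_SER_2)
records.map(_.sum).foreachRDD(rdd => println("Sums: " + rdd.collect().mkString(" ")))
ssc.start()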
diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala
new file mode 100644
index 0000000000..086752ac55
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala
@@ -0,0 +1,130 @@
+package spark.streaming
+
+import spark.RDD
+import spark.rdd.BlockRDD
+import spark.Partitioner
+import spark.rdd.MapPartitionsRDD
+import spark.SparkContext._
+import spark.storage.StorageLevel
+
+
+class StateRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: Iterator[T] => Iterator[U],
+ rememberPartitioner: Boolean
+ ) extends MapPartitionsRDD[U, T](prev, f) {
+ override val partitioner = if (rememberPartitioner) prev.partitioner else None
+}
+
+class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest](
+ parent: DStream[(K, V)],
+ updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+ partitioner: Partitioner,
+ rememberPartitioner: Boolean
+ ) extends DStream[(K, S)](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime = parent.slideTime
+
+ override def getOrCompute(time: Time): Option[RDD[(K, S)]] = {
+ generatedRDDs.get(time) match {
+ case Some(oldRDD) => {
+ if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) {
+ val r = oldRDD
+ val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index)
+ val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) {
+ override val partitioner = oldRDD.partitioner
+ }
+ generatedRDDs.update(time, checkpointedRDD)
+ logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id)
+ Some(checkpointedRDD)
+ } else {
+ Some(oldRDD)
+ }
+ }
+ case None => {
+ if (isTimeValid(time)) {
+ compute(time) match {
+ case Some(newRDD) => {
+ if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
+ newRDD.persist(checkpointLevel)
+ logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
+ } else if (storageLevel != StorageLevel.NONE) {
+ newRDD.persist(storageLevel)
+ logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
+ }
+ generatedRDDs.put(time, newRDD)
+ Some(newRDD)
+ }
+ case None => {
+ None
+ }
+ }
+ } else {
+ None
+ }
+ }
+ }
+ }
+
+ override def compute(validTime: Time): Option[RDD[(K, S)]] = {
+
+ // Try to get the previous state RDD
+ getOrCompute(validTime - slideTime) match {
+
+ case Some(prevStateRDD) => { // If previous state RDD exists
+
+ // Try to get the parent RDD
+ parent.getOrCompute(validTime) match {
+ case Some(parentRDD) => { // If parent RDD exists, then compute as usual
+
+ // Define the function for the mapPartition operation on cogrouped RDD;
+ // first map the cogrouped tuple to tuples of required type,
+ // and then apply the update function
+ val updateFuncLocal = updateFunc
+ val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => {
+ val i = iterator.map(t => {
+ (t._1, t._2._1, t._2._2.headOption)
+ })
+ updateFuncLocal(i)
+ }
+ val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
+ val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner)
+ //logDebug("Generating state RDD for time " + validTime)
+ return Some(stateRDD)
+ }
+ case None => { // If parent RDD does not exist, then return old state RDD
+ return Some(prevStateRDD)
+ }
+ }
+ }
+
+ case None => { // If previous session RDD does not exist (first input data)
+
+ // Try to get the parent RDD
+ parent.getOrCompute(validTime) match {
+ case Some(parentRDD) => { // If parent RDD exists, then compute as usual
+
+ // Define the function for the mapPartition operation on grouped RDD;
+ // first map the grouped tuple to tuples of required type,
+ // and then apply the update function
+ val updateFuncLocal = updateFunc
+ val finalFunc = (iterator: Iterator[(K, Seq[V])]) => {
+ updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None)))
+ }
+
+ val groupedRDD = parentRDD.groupByKey(partitioner)
+ val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner)
+ //logDebug("Generating state RDD for time " + validTime + " (first)")
+ return Some(sessionRDD)
+ }
+ case None => { // If parent RDD does not exist, then nothing to do!
+ //logDebug("Not generating state RDD (no previous state, no parent)")
+ return None
+ }
+ }
+ }
+ }
+ }
+}
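A sketch of the updateStateByKey path that ends up constructing this StateDStream, keeping a running per-word count; the state type must be a reference type (S <: AnyRef), so the total is wrapped in a small case class, and the master, application name, host and port are placeholders:

import spark.streaming._
import spark.streaming.StreamingContext._

case class Count(total: Int)

val ssc = new StreamingContext("local[2]", "RunningWordCount")
ssc.setBatchDuration(Seconds(2))

val words = ssc.networkTextStream("localhost", 9999)
  .flatMap(_.split(" "))
  .map(word => (word, 1))

// Returning None for a key would drop it from the state; here every seen key is kept.
val runningCounts = words.updateStateByKey[Count] {
  (newValues: Seq[Int], state: Option[Count]) =>
    Some(Count(state.map(_.total).getOrElse(0) + newValues.sum))
}
runningCounts.print()
ssc.start()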
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
new file mode 100644
index 0000000000..7c7b3afe47
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -0,0 +1,229 @@
+package spark.streaming
+
+import spark.RDD
+import spark.Logging
+import spark.SparkEnv
+import spark.SparkContext
+import spark.storage.StorageLevel
+
+import scala.collection.mutable.Queue
+
+import java.io.InputStream
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+
+class StreamingContext (
+ sc_ : SparkContext,
+ cp_ : Checkpoint
+ ) extends Logging {
+
+ def this(sparkContext: SparkContext) = this(sparkContext, null)
+
+ def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) =
+ this(new SparkContext(master, frameworkName, sparkHome, jars), null)
+
+ def this(file: String) = this(null, Checkpoint.loadFromFile(file))
+
+ def this(cp_ : Checkpoint) = this(null, cp_)
+
+ initLogging()
+
+ if (sc_ == null && cp_ == null) {
+ throw new Exception("Streaming Context cannot be initilalized with " +
+ "both SparkContext and checkpoint as null")
+ }
+
+ val isCheckpointPresent = (cp_ != null)
+
+ val sc: SparkContext = {
+ if (isCheckpointPresent) {
+ new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars)
+ } else {
+ sc_
+ }
+ }
+
+ val env = SparkEnv.get
+
+ val graph: DStreamGraph = {
+ if (isCheckpointPresent) {
+
+ cp_.graph.setContext(this)
+ cp_.graph
+ } else {
+ new DStreamGraph()
+ }
+ }
+
+ val nextNetworkInputStreamId = new AtomicInteger(0)
+ var networkInputTracker: NetworkInputTracker = null
+
+ private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null
+ private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null
+ private[streaming] var receiverJobThread: Thread = null
+ private[streaming] var scheduler: Scheduler = null
+
+ def setBatchDuration(duration: Time) {
+ graph.setBatchDuration(duration)
+ }
+
+ def setRememberDuration(duration: Time) {
+ graph.setRememberDuration(duration)
+ }
+
+ def setCheckpointDetails(file: String, interval: Time) {
+ checkpointFile = file
+ checkpointInterval = interval
+ }
+
+ private[streaming] def getInitialCheckpoint(): Checkpoint = {
+ if (isCheckpointPresent) cp_ else null
+ }
+
+ private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+
+ def networkTextStream(
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+ ): DStream[String] = {
+ networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel)
+ }
+
+ def networkStream[T: ClassManifest](
+ hostname: String,
+ port: Int,
+ converter: (InputStream) => Iterator[T],
+ storageLevel: StorageLevel
+ ): DStream[T] = {
+ val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
+ def rawNetworkStream[T: ClassManifest](
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
+ ): DStream[T] = {
+ val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
+ /**
+ * This function creates an input stream that monitors a Hadoop-compatible
+ * filesystem directory for new files and executes the necessary processing on them.
+ */
+ def fileStream[
+ K: ClassManifest,
+ V: ClassManifest,
+ F <: NewInputFormat[K, V]: ClassManifest
+ ](directory: String): DStream[(K, V)] = {
+ val inputStream = new FileInputDStream[K, V, F](this, directory)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
+ def textFileStream(directory: String): DStream[String] = {
+ fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString)
+ }
+
+ /**
+ * This function creates an input stream from a queue of RDDs. In each batch,
+ * it will process either one or all of the RDDs currently in the queue.
+ */
+ def queueStream[T: ClassManifest](
+ queue: Queue[RDD[T]],
+ oneAtATime: Boolean = true,
+ defaultRDD: RDD[T] = null
+ ): DStream[T] = {
+ val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
+ def queueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = {
+ val queue = new Queue[RDD[T]]
+ val inputStream = queueStream(queue, true, null)
+ queue ++= array
+ inputStream
+ }
+
+ /**
+ * This function registers an InputDStream as an input stream that will be
+ * started (by calling InputDStream.start()) to receive input data.
+ */
+ def registerInputStream(inputStream: InputDStream[_]) {
+ graph.addInputStream(inputStream)
+ }
+
+ /**
+ * This function registers a DStream as an output stream that will be
+ * computed every interval.
+ */
+ def registerOutputStream(outputStream: DStream[_]) {
+ graph.addOutputStream(outputStream)
+ }
+
+ def validate() {
+ assert(graph != null, "Graph is null")
+ graph.validate()
+ }
+
+ /**
+ * This function starts the execution of the streams.
+ */
+ def start() {
+ validate()
+
+ val networkInputStreams = graph.getInputStreams().filter(s => s match {
+ case n: NetworkInputDStream[_] => true
+ case _ => false
+ }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
+
+ if (networkInputStreams.length > 0) {
+ // Start the network input tracker (must start before receivers)
+ networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
+ networkInputTracker.start()
+ }
+
+ Thread.sleep(1000)
+
+ // Start the scheduler
+ scheduler = new Scheduler(this)
+ scheduler.start()
+ }
+
+ /**
+ * This function stops the execution of the streams.
+ */
+ def stop() {
+ try {
+ if (scheduler != null) scheduler.stop()
+ if (networkInputTracker != null) networkInputTracker.stop()
+ if (receiverJobThread != null) receiverJobThread.interrupt()
+ sc.stop()
+ } catch {
+ case e: Exception => logWarning("Error while stopping", e)
+ }
+
+ logInfo("StreamingContext stopped")
+ }
+
+ def doCheckpoint(currentTime: Time) {
+ new Checkpoint(this, currentTime).saveToFile(checkpointFile)
+ }
+}
+
+
+object StreamingContext {
+ implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = {
+ new PairDStreamFunctions[K, V](stream)
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala
new file mode 100644
index 0000000000..9ddb65249a
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Time.scala
@@ -0,0 +1,56 @@
+package spark.streaming
+
+case class Time(millis: Long) {
+
+ def < (that: Time): Boolean = (this.millis < that.millis)
+
+ def <= (that: Time): Boolean = (this.millis <= that.millis)
+
+ def > (that: Time): Boolean = (this.millis > that.millis)
+
+ def >= (that: Time): Boolean = (this.millis >= that.millis)
+
+ def + (that: Time): Time = Time(millis + that.millis)
+
+ def - (that: Time): Time = Time(millis - that.millis)
+
+ def * (times: Int): Time = Time(millis * times)
+
+ def floor(that: Time): Time = {
+ val t = that.millis
+ val m = math.floor(this.millis / t).toLong
+ Time(m * t)
+ }
+
+ def isMultipleOf(that: Time): Boolean =
+ (this.millis % that.millis == 0)
+
+ def isZero: Boolean = (this.millis == 0)
+
+ override def toString: String = (millis.toString + " ms")
+
+ def toFormattedString: String = millis.toString
+
+ def milliseconds: Long = millis
+}
+
+object Time {
+ val zero = Time(0)
+
+ implicit def toTime(long: Long) = Time(long)
+
+ implicit def toLong(time: Time) = time.milliseconds
+}
+
+object Milliseconds {
+ def apply(milliseconds: Long) = Time(milliseconds)
+}
+
+object Seconds {
+ def apply(seconds: Long) = Time(seconds * 1000)
+}
+
+object Minutes {
+ def apply(minutes: Long) = Time(minutes * 60000)
+}
+
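A few spot checks of the Time arithmetic above; every value is just a wrapper around milliseconds:

import spark.streaming._

val batch  = Seconds(2)              // Time(2000)
val window = Minutes(1)              // Time(60000)
window.isMultipleOf(batch)           // true, since 60000 % 2000 == 0
(window - batch).milliseconds        // 58000
Seconds(5) * 3                       // Time(15000)
Time(10000).floor(Seconds(3))        // Time(9000)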
diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala
new file mode 100644
index 0000000000..ce89a3f99b
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala
@@ -0,0 +1,36 @@
+package spark.streaming
+
+import spark.RDD
+import spark.rdd.UnionRDD
+
+
+class WindowedDStream[T: ClassManifest](
+ parent: DStream[T],
+ _windowTime: Time,
+ _slideTime: Time)
+ extends DStream[T](parent.ssc) {
+
+ if (!_windowTime.isMultipleOf(parent.slideTime))
+ throw new Exception("The window duration of WindowedDStream (" + _windowTime + ") " +
+ "must be a multiple of the slide duration of the parent DStream (" + parent.slideTime + ")")
+
+ if (!_slideTime.isMultipleOf(parent.slideTime))
+ throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " +
+ "must be a multiple of the slide duration of the parent DStream (" + parent.slideTime + ")")
+
+ def windowTime: Time = _windowTime
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = _slideTime
+
+ override def parentRememberDuration: Time = rememberDuration + windowTime
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime)
+ Some(new UnionRDD(ssc.sc, parent.slice(currentWindow)))
+ }
+}
+
+
+
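A sketch of the plain window operation that this class backs, with the master, application name, host and port as placeholders:

import spark.streaming._

val ssc = new StreamingContext("local[2]", "WindowExample")
ssc.setBatchDuration(Seconds(5))

// Every 10 seconds, recompute over everything received in the last 30 seconds
// (both durations are multiples of the 5 second batch duration, as required above).
val events = ssc.networkTextStream("localhost", 9999)
val lastHalfMinute = events.window(Seconds(30), Seconds(10))
lastHalfMinute.count().foreachRDD(rdd => println("Events in window: " + rdd.collect().mkString))
ssc.start()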
diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala
new file mode 100644
index 0000000000..d2fdabd659
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala
@@ -0,0 +1,32 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.storage.StorageLevel
+import spark.streaming._
+import spark.streaming.StreamingContext._
+
+object CountRaw {
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println("Usage: CountRaw <master> <numStreams> <host> <port> <batchMillis>")
+ System.exit(1)
+ }
+
+ val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "CountRaw")
+ ssc.setBatchDuration(Milliseconds(batchMillis))
+
+ // Make sure some tasks have started on each node
+ ssc.sc.parallelize(1 to 1000, 1000).count()
+ ssc.sc.parallelize(1 to 1000, 1000).count()
+ ssc.sc.parallelize(1 to 1000, 1000).count()
+
+ val rawStreams = (1 to numStreams).map(_ =>
+ ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
+ val union = new UnionDStream(rawStreams)
+ union.map(_.length + 2).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString))
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala
new file mode 100644
index 0000000000..d68611abd6
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala
@@ -0,0 +1,47 @@
+package spark.streaming.examples
+
+import spark.streaming.StreamingContext
+import spark.streaming.StreamingContext._
+import spark.streaming.Seconds
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
+
+
+object FileStream {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: FileStream <master> <new HDFS compatible directory>")
+ System.exit(1)
+ }
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(args(0), "FileStream")
+ ssc.setBatchDuration(Seconds(2))
+
+ // Create the new directory
+ val directory = new Path(args(1))
+ val fs = directory.getFileSystem(new Configuration())
+ if (fs.exists(directory)) throw new Exception("This directory already exists")
+ fs.mkdirs(directory)
+ fs.deleteOnExit(directory)
+
+ // Create the FileInputDStream on the directory and use the
+ // stream to count words in new files created
+ val inputStream = ssc.textFileStream(directory.toString)
+ val words = inputStream.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+
+ // Creating new files in the directory
+ val text = "This is a text file"
+ for (i <- 1 to 30) {
+ ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10)
+ .saveAsTextFile(new Path(directory, i.toString).toString)
+ Thread.sleep(1000)
+ }
+ Thread.sleep(5000) // Waiting for the file to be processed
+ ssc.stop()
+ System.exit(0)
+ }
+}
\ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
new file mode 100644
index 0000000000..df96a811da
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
@@ -0,0 +1,76 @@
+package spark.streaming.examples
+
+import spark.streaming._
+import spark.streaming.StreamingContext._
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
+
+object FileStreamWithCheckpoint {
+
+ def main(args: Array[String]) {
+
+ if (args.size != 3) {
+ println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>")
+ println("FileStreamWithCheckpoint restart <directory> <checkpoint file>")
+ System.exit(-1)
+ }
+
+ val directory = new Path(args(1))
+ val checkpointFile = args(2)
+
+ val ssc: StreamingContext = {
+
+ if (args(0) == "restart") {
+
+ // Recreate the streaming context from the specified checkpoint file
+ new StreamingContext(checkpointFile)
+
+ } else {
+
+ // Create directory if it does not exist
+ val fs = directory.getFileSystem(new Configuration())
+ if (!fs.exists(directory)) fs.mkdirs(directory)
+
+ // Create new streaming context
+ val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint")
+ ssc_.setBatchDuration(Seconds(1))
+ ssc_.setCheckpointDetails(checkpointFile, Seconds(1))
+
+ // Setup the streaming computation
+ val inputStream = ssc_.textFileStream(directory.toString)
+ val words = inputStream.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+
+ ssc_
+ }
+ }
+
+ // Start the stream computation
+ startFileWritingThread(directory.toString)
+ ssc.start()
+ }
+
+ def startFileWritingThread(directory: String) {
+
+ val fs = new Path(directory).getFileSystem(new Configuration())
+
+ val fileWritingThread = new Thread() {
+ override def run() {
+ val r = new scala.util.Random()
+ val text = "This is a sample text file with a random number "
+ while(true) {
+ val number = r.nextInt()
+ val file = new Path(directory, number.toString)
+ val fos = fs.create(file)
+ fos.writeChars(text + number)
+ fos.close()
+ println("Created text file " + file)
+ Thread.sleep(1000)
+ }
+ }
+ }
+ fileWritingThread.start()
+ }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala
new file mode 100644
index 0000000000..b1faa65c17
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala
@@ -0,0 +1,64 @@
+package spark.streaming.examples
+
+import spark.SparkContext
+import SparkContext._
+import spark.streaming._
+import StreamingContext._
+
+import spark.storage.StorageLevel
+
+import scala.util.Sorting
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Queue
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+
+object Grep2 {
+
+ def warmup(sc: SparkContext) {
+ (0 until 10).foreach {i =>
+ sc.parallelize(1 to 20000000, 1000)
+ .map(x => (x % 337, x % 1331))
+ .reduceByKey(_ + _)
+ .count()
+ }
+ }
+
+ def main (args: Array[String]) {
+
+ if (args.length != 6) {
+ println ("Usage: Grep2 <host> <file> <mapTasks> <reduceTasks> <batchMillis> <chkptMillis>")
+ System.exit(1)
+ }
+
+ val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args
+
+ val batchDuration = Milliseconds(batchMillis.toLong)
+
+ val ssc = new StreamingContext(master, "Grep2")
+ ssc.setBatchDuration(batchDuration)
+
+ //warmup(ssc.sc)
+
+ val data = ssc.sc.textFile(file, mapTasks.toInt).persist(
+ new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas
+ println("Data count: " + data.count())
+ println("Data count: " + data.count())
+ println("Data count: " + data.count())
+
+ val sentences = new ConstantInputDStream(ssc, data)
+ ssc.registerInputStream(sentences)
+
+ sentences.filter(_.contains("Culpepper")).count().foreachRDD(r =>
+ println("Grep count: " + r.collect().mkString))
+
+ ssc.start()
+
+ while(true) { Thread.sleep(1000) }
+ }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala
new file mode 100644
index 0000000000..b1e1a613fe
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala
@@ -0,0 +1,33 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.storage.StorageLevel
+import spark.streaming._
+import spark.streaming.StreamingContext._
+
+object GrepRaw {
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println("Usage: GrepRaw <master> <numStreams> <host> <port> <batchMillis>")
+ System.exit(1)
+ }
+
+ val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "GrepRaw")
+ ssc.setBatchDuration(Milliseconds(batchMillis))
+
+ // Make sure some tasks have started on each node
+ ssc.sc.parallelize(1 to 1000, 1000).count()
+ ssc.sc.parallelize(1 to 1000, 1000).count()
+ ssc.sc.parallelize(1 to 1000, 1000).count()
+
+ val rawStreams = (1 to numStreams).map(_ =>
+ ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
+ val union = new UnionDStream(rawStreams)
+ union.filter(_.contains("Culpepper")).count().foreachRDD(r =>
+ println("Grep count: " + r.collect().mkString))
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala
new file mode 100644
index 0000000000..2af51bad28
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala
@@ -0,0 +1,41 @@
+package spark.streaming.examples
+
+import spark.RDD
+import spark.streaming.StreamingContext
+import spark.streaming.StreamingContext._
+import spark.streaming.Seconds
+
+import scala.collection.mutable.SynchronizedQueue
+
+object QueueStream {
+
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ System.err.println("Usage: QueueStream <master>")
+ System.exit(1)
+ }
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(args(0), "QueueStream")
+ ssc.setBatchDuration(Seconds(1))
+
+ // Create the queue through which RDDs can be pushed to
+ // a QueueInputDStream
+ val rddQueue = new SynchronizedQueue[RDD[Int]]()
+
+ // Create the QueueInputDStream and use it to do some processing
+ val inputStream = ssc.queueStream(rddQueue)
+ val mappedStream = inputStream.map(x => (x % 10, 1))
+ val reducedStream = mappedStream.reduceByKey(_ + _)
+ reducedStream.print()
+ ssc.start()
+
+ // Create some RDDs and push them into the queue
+ for (i <- 1 to 30) {
+ rddQueue += ssc.sc.makeRDD(1 to 1000, 10)
+ Thread.sleep(1000)
+ }
+ ssc.stop()
+ System.exit(0)
+ }
+}
\ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
new file mode 100644
index 0000000000..57fd10f0a5
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
@@ -0,0 +1,95 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.SparkContext
+import spark.SparkContext._
+import spark.storage.StorageLevel
+import spark.streaming._
+import spark.streaming.StreamingContext._
+
+import WordCount2_ExtraFunctions._
+
+object TopKWordCountRaw {
+ def moreWarmup(sc: SparkContext) {
+ (0 until 40).foreach {i =>
+ sc.parallelize(1 to 20000000, 1000)
+ .map(_ % 1331).map(_.toString)
+ .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10)
+ .collect()
+ }
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 7) {
+ System.err.println("Usage: TopKWordCountRaw <master> <streams> <host> <port> <batchMs> <chkptMs> <reduces>")
+ System.exit(1)
+ }
+
+ val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs),
+ IntParam(chkptMs), IntParam(reduces)) = args
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "TopKWordCountRaw")
+ ssc.setBatchDuration(Milliseconds(batchMs))
+
+ // Make sure some tasks have started on each node
+ moreWarmup(ssc.sc)
+
+ val rawStreams = (1 to streams).map(_ =>
+ ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
+ val union = new UnionDStream(rawStreams)
+
+ val windowedCounts = union.mapPartitions(splitAndCountPartitions)
+ .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
+ windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
+ Milliseconds(chkptMs))
+ //windowedCounts.print() // TODO: something else?
+
+ def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = {
+ val taken = new Array[(String, Long)](k)
+
+ var i = 0
+ var len = 0
+ var done = false
+ var value: (String, Long) = null
+ var swap: (String, Long) = null
+ var count = 0
+
+ while(data.hasNext) {
+ value = data.next
+ count += 1
+ println("count = " + count)
+ if (len == 0) {
+ taken(0) = value
+ len = 1
+ } else if (len < k || value._2 > taken(len - 1)._2) {
+ if (len < k) {
+ len += 1
+ }
+ taken(len - 1) = value
+ i = len - 1
+ while(i > 0 && taken(i - 1)._2 < taken(i)._2) {
+ swap = taken(i)
+ taken(i) = taken(i-1)
+ taken(i - 1) = swap
+ i -= 1
+ }
+ }
+ }
+ println("Took " + len + " out of " + count + " items")
+ return taken.toIterator
+ }
+
+ val k = 50
+ val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k))
+ partialTopKWindowedCounts.foreachRDD(rdd => {
+ val collectedCounts = rdd.collect
+ println("Collected " + collectedCounts.size + " items")
+ topK(collectedCounts.toIterator, k).foreach(println)
+ })
+
+// windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
+
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
new file mode 100644
index 0000000000..0d2e62b955
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
@@ -0,0 +1,115 @@
+package spark.streaming.examples
+
+import spark.SparkContext
+import SparkContext._
+import spark.streaming._
+import StreamingContext._
+
+import spark.storage.StorageLevel
+
+import scala.util.Sorting
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Queue
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+
+object WordCount2_ExtraFunctions {
+
+ def add(v1: Long, v2: Long) = (v1 + v2)
+
+ def subtract(v1: Long, v2: Long) = (v1 - v2)
+
+ def max(v1: Long, v2: Long) = math.max(v1, v2)
+
+ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = {
+ //val map = new java.util.HashMap[String, Long]
+ val map = new OLMap[String]
+ var i = 0
+ var j = 0
+ while (iter.hasNext) {
+ val s = iter.next()
+ i = 0
+ while (i < s.length) {
+ j = i
+ while (j < s.length && s.charAt(j) != ' ') {
+ j += 1
+ }
+ if (j > i) {
+ val w = s.substring(i, j)
+ val c = map.getLong(w)
+ map.put(w, c + 1)
+/*
+ if (c == null) {
+ map.put(w, 1)
+ } else {
+ map.put(w, c + 1)
+ }
+*/
+ }
+ i = j
+ while (i < s.length && s.charAt(i) == ' ') {
+ i += 1
+ }
+ }
+ }
+ map.toIterator.map{case (k, v) => (k, v)}
+ }
+}
+
+object WordCount2 {
+
+ def warmup(sc: SparkContext) {
+ (0 until 3).foreach {i =>
+ sc.parallelize(1 to 20000000, 500)
+ .map(x => (x % 337, x % 1331))
+ .reduceByKey(_ + _, 100)
+ .count()
+ }
+ }
+
+ def main (args: Array[String]) {
+
+ if (args.length != 6) {
+ println ("Usage: WordCount2 <host> <file> <mapTasks> <reduceTasks> <batchMillis> <chkptMillis>")
+ System.exit(1)
+ }
+
+ val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args
+
+ val batchDuration = Milliseconds(batchMillis.toLong)
+
+ val ssc = new StreamingContext(master, "WordCount2")
+ ssc.setBatchDuration(batchDuration)
+
+ //warmup(ssc.sc)
+
+ val data = ssc.sc.textFile(file, mapTasks.toInt).persist(
+ new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas
+ println("Data count: " + data.map(x => if (x == "") 1 else x.split(" ").size / x.split(" ").size).count())
+ println("Data count: " + data.count())
+ println("Data count: " + data.count())
+
+ val sentences = new ConstantInputDStream(ssc, data)
+ ssc.registerInputStream(sentences)
+
+ import WordCount2_ExtraFunctions._
+
+ val windowedCounts = sentences
+ .mapPartitions(splitAndCountPartitions)
+ .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt)
+ windowedCounts.persist(StorageLevel.MEMORY_ONLY,
+ StorageLevel.MEMORY_ONLY_2,
+ //new StorageLevel(false, true, true, 3),
+ Milliseconds(chkptMillis.toLong))
+ windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
+
+ ssc.start()
+
+ while(true) { Thread.sleep(1000) }
+ }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala
new file mode 100644
index 0000000000..591cb141c3
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala
@@ -0,0 +1,26 @@
+package spark.streaming.examples
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+
+object WordCountHdfs {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: WordCountHdfs <master> <directory>")
+ System.exit(1)
+ }
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(args(0), "WordCountHdfs")
+ ssc.setBatchDuration(Seconds(2))
+
+ // Create the FileInputDStream on the directory and use the
+ // stream to count words in new files created
+ val lines = ssc.textFileStream(args(1))
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala
new file mode 100644
index 0000000000..ba1bd1de7c
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala
@@ -0,0 +1,25 @@
+package spark.streaming.examples
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+
+object WordCountNetwork {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: WordCountNetwork <master> <hostname> <port>")
+ System.exit(1)
+ }
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(args(0), "WordCountNetwork")
+ ssc.setBatchDuration(Seconds(2))
+
+ // Create a NetworkInputDStream on target ip:port and count the
+ // words in the input stream of \n delimited text (e.g. generated by 'nc')
+ val lines = ssc.networkTextStream(args(1), args(2).toInt)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
new file mode 100644
index 0000000000..abfd12890f
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
@@ -0,0 +1,51 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.SparkContext
+import spark.SparkContext._
+import spark.storage.StorageLevel
+import spark.streaming._
+import spark.streaming.StreamingContext._
+
+import WordCount2_ExtraFunctions._
+
+object WordCountRaw {
+ def moreWarmup(sc: SparkContext) {
+ (0 until 40).foreach {i =>
+ sc.parallelize(1 to 20000000, 1000)
+ .map(_ % 1331).map(_.toString)
+ .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10)
+ .collect()
+ }
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 7) {
+ System.err.println("Usage: WordCountRaw <master> <streams> <host> <port> <batchMs> <chkptMs> <reduces>")
+ System.exit(1)
+ }
+
+ val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs),
+ IntParam(chkptMs), IntParam(reduces)) = args
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "WordCountRaw")
+ ssc.setBatchDuration(Milliseconds(batchMs))
+
+ // Make sure some tasks have started on each node
+ moreWarmup(ssc.sc)
+
+ val rawStreams = (1 to streams).map(_ =>
+ ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
+ val union = new UnionDStream(rawStreams)
+
+ val windowedCounts = union.mapPartitions(splitAndCountPartitions)
+ .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
+ windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
+ Milliseconds(chkptMs))
+ //windowedCounts.print() // TODO: something else?
+ windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
+
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
new file mode 100644
index 0000000000..9d44da2b11
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
@@ -0,0 +1,73 @@
+package spark.streaming.examples
+
+import spark.SparkContext
+import SparkContext._
+import spark.streaming._
+import StreamingContext._
+
+import spark.storage.StorageLevel
+
+import scala.util.Sorting
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Queue
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+
+object WordMax2 {
+
+ def warmup(sc: SparkContext) {
+ (0 until 10).foreach {i =>
+ sc.parallelize(1 to 20000000, 1000)
+ .map(x => (x % 337, x % 1331))
+ .reduceByKey(_ + _)
+ .count()
+ }
+ }
+
+ def main (args: Array[String]) {
+
+ if (args.length != 6) {
+ println ("Usage: WordMax2 <host> <file> <mapTasks> <reduceTasks> <batchMillis> <chkptMillis>")
+ System.exit(1)
+ }
+
+ val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args
+
+ val batchDuration = Milliseconds(batchMillis.toLong)
+
+ val ssc = new StreamingContext(master, "WordMax2")
+ ssc.setBatchDuration(batchDuration)
+
+ //warmup(ssc.sc)
+
+ val data = ssc.sc.textFile(file, mapTasks.toInt).persist(
+ new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas
+ println("Data count: " + data.count())
+ println("Data count: " + data.count())
+ println("Data count: " + data.count())
+
+ val sentences = new ConstantInputDStream(ssc, data)
+ ssc.registerInputStream(sentences)
+
+ import WordCount2_ExtraFunctions._
+
+ val windowedCounts = sentences
+ .mapPartitions(splitAndCountPartitions)
+ .reduceByKey(add _, reduceTasks.toInt)
+ .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
+ Milliseconds(chkptMillis.toLong))
+ .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt)
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
+ // Milliseconds(chkptMillis.toLong))
+ windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
+
+ ssc.start()
+
+ while(true) { Thread.sleep(1000) }
+ }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala
new file mode 100644
index 0000000000..ed087e4ea8
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala
@@ -0,0 +1,84 @@
+package spark.streaming.util
+
+import spark.streaming._
+
+trait Clock {
+ def currentTime(): Long
+ def waitTillTime(targetTime: Long): Long
+}
+
+
+class SystemClock() extends Clock {
+
+ val minPollTime = 25L
+
+ def currentTime(): Long = {
+ System.currentTimeMillis()
+ }
+
+ def waitTillTime(targetTime: Long): Long = {
+ var currentTime = 0L
+ currentTime = System.currentTimeMillis()
+
+ var waitTime = targetTime - currentTime
+ if (waitTime <= 0) {
+ return currentTime
+ }
+
+ val pollTime = {
+ if (waitTime / 10.0 > minPollTime) {
+ (waitTime / 10.0).toLong
+ } else {
+ minPollTime
+ }
+ }
+
+
+ while (true) {
+ currentTime = System.currentTimeMillis()
+ waitTime = targetTime - currentTime
+
+ if (waitTime <= 0) {
+
+ return currentTime
+ }
+ val sleepTime =
+ if (waitTime < pollTime) {
+ waitTime
+ } else {
+ pollTime
+ }
+ Thread.sleep(sleepTime)
+ }
+ return -1
+ }
+}
+
+class ManualClock() extends Clock {
+
+ var time = 0L
+
+ def currentTime() = time
+
+ def setTime(timeToSet: Long) = {
+ this.synchronized {
+ time = timeToSet
+ this.notifyAll()
+ }
+ }
+
+ def addToTime(timeToAdd: Long) = {
+ this.synchronized {
+ time += timeToAdd
+ this.notifyAll()
+ }
+ }
+ def waitTillTime(targetTime: Long): Long = {
+ this.synchronized {
+ while (time < targetTime) {
+ this.wait(100)
+ }
+ }
+ return currentTime()
+ }
+}
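
Because ManualClock only advances when told to, tests can release waiters deterministically instead of sleeping. A minimal sketch, assuming the spark.streaming.util package above; the demo object, thread and times are illustrative.

    import spark.streaming.util.ManualClock

    // Sketch only: a waiter blocks in waitTillTime until addToTime pushes the
    // manual time to its target, at which point notifyAll() releases it.
    object ManualClockDemo {
      def main(args: Array[String]) {
        val clock = new ManualClock()
        val waiter = new Thread() {
          override def run() {
            val t = clock.waitTillTime(2000)
            println("Released at manual time " + t)
          }
        }
        waiter.start()
        clock.addToTime(1000)   // time = 1000, still below the target of 2000
        clock.addToTime(1000)   // time = 2000, waitTillTime returns
        waiter.join()
      }
    }
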
diff --git a/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala b/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala
new file mode 100644
index 0000000000..cde868a0c9
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala
@@ -0,0 +1,157 @@
+package spark.streaming.util
+
+import spark.Logging
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue}
+
+import java.net._
+import java.io._
+import java.nio._
+import java.nio.charset._
+import java.nio.channels._
+import java.nio.channels.spi._
+
+abstract class ConnectionHandler(host: String, port: Int, connect: Boolean)
+extends Thread with Logging {
+
+ val selector = SelectorProvider.provider.openSelector()
+ val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+
+ initLogging()
+
+ override def run() {
+ try {
+ if (connect) {
+ connect()
+ } else {
+ listen()
+ }
+
+ var interrupted = false
+ while(!interrupted) {
+
+ preSelect()
+
+ while(!interestChangeRequests.isEmpty) {
+ val (key, ops) = interestChangeRequests.dequeue
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ def intToOpStr(op: Int): String = {
+ val opStrs = new ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+
+ selector.select()
+ interrupted = Thread.currentThread.isInterrupted
+
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext) {
+ val key = selectedKeys.next.asInstanceOf[SelectionKey]
+ selectedKeys.remove()
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ accept(key)
+ } else if (key.isConnectable) {
+ finishConnect(key)
+ } else if (key.isReadable) {
+ read(key)
+ } else if (key.isWritable) {
+ write(key)
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logError("Error in select loop", e)
+ }
+ }
+ }
+
+ def connect() {
+ val socketAddress = new InetSocketAddress(host, port)
+ val channel = SocketChannel.open()
+ channel.configureBlocking(false)
+ channel.socket.setReuseAddress(true)
+ channel.socket.setTcpNoDelay(true)
+ channel.connect(socketAddress)
+ channel.register(selector, SelectionKey.OP_CONNECT)
+ logInfo("Initiating connection to [" + socketAddress + "]")
+ }
+
+ def listen() {
+ val channel = ServerSocketChannel.open()
+ channel.configureBlocking(false)
+ channel.socket.setReuseAddress(true)
+ channel.socket.setReceiveBufferSize(256 * 1024)
+ channel.socket.bind(new InetSocketAddress(port))
+ channel.register(selector, SelectionKey.OP_ACCEPT)
+ logInfo("Listening on port " + port)
+ }
+
+ def finishConnect(key: SelectionKey) {
+ try {
+ val channel = key.channel.asInstanceOf[SocketChannel]
+ val address = channel.socket.getRemoteSocketAddress
+ channel.finishConnect()
+ logInfo("Connected to [" + host + ":" + port + "]")
+ ready(key)
+ } catch {
+ case e: IOException => {
+ logError("Error finishing connect to " + host + ":" + port)
+ close(key)
+ }
+ }
+ }
+
+ def accept(key: SelectionKey) {
+ try {
+ val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+ val channel = serverChannel.accept()
+ val address = channel.socket.getRemoteSocketAddress
+ channel.configureBlocking(false)
+ logInfo("Accepted connection from [" + address + "]")
+ ready(channel.register(selector, 0))
+ } catch {
+ case e: IOException => {
+ logError("Error accepting connection", e)
+ }
+ }
+ }
+
+ def changeInterest(key: SelectionKey, ops: Int) {
+ logTrace("Added request to change ops to " + ops)
+ interestChangeRequests += ((key, ops))
+ }
+
+ def ready(key: SelectionKey)
+
+ def preSelect() {
+ }
+
+ def read(key: SelectionKey) {
+ throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+ }
+
+ def write(key: SelectionKey) {
+ throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+ }
+
+ def close(key: SelectionKey) {
+ try {
+ key.channel.close()
+ key.cancel()
+ Thread.currentThread.interrupt
+ } catch {
+ case e: Exception => logError("Error closing connection", e)
+ }
+ }
+}
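
ConnectionHandler leaves ready() abstract and expects subclasses to override read() and/or write() for the directions they actually use. A minimal sketch of a reading subclass, assuming the class above; the subclass name and buffer size are illustrative.

    import java.nio.ByteBuffer
    import java.nio.channels.{SelectionKey, SocketChannel}
    import spark.streaming.util.ConnectionHandler

    // Sketch only: connect to host:port, switch to OP_READ once connected,
    // and log how many bytes each read returns.
    class ByteCountingHandler(host: String, port: Int)
    extends ConnectionHandler(host, port, true) {

      override def ready(key: SelectionKey) {
        changeInterest(key, SelectionKey.OP_READ)
      }

      override def read(key: SelectionKey) {
        val channel = key.channel.asInstanceOf[SocketChannel]
        val buffer = ByteBuffer.allocate(4096)
        val n = channel.read(buffer)
        if (n < 0) close(key) else logInfo("Read " + n + " bytes")
      }
    }
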
diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
new file mode 100644
index 0000000000..d8b987ec86
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
@@ -0,0 +1,60 @@
+package spark.streaming.util
+
+import java.nio.ByteBuffer
+import spark.util.{RateLimitedOutputStream, IntParam}
+import java.net.ServerSocket
+import spark.{Logging, KryoSerializer}
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import io.Source
+import java.io.IOException
+
+/**
+ * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a
+ * specified rate. Used to feed data into RawInputDStream.
+ */
+object RawTextSender extends Logging {
+ def main(args: Array[String]) {
+ if (args.length != 4) {
+ System.err.println("Usage: RawTextSender <port> <file> <blockSize> <bytesPerSec>")
+ System.exit(1)
+ }
+ // Parse the arguments using a pattern match
+ val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args
+
+ // Repeat the input data multiple times to fill in a buffer
+ val lines = Source.fromFile(file).getLines().toArray
+ val bufferStream = new FastByteArrayOutputStream(blockSize + 1000)
+ val ser = new KryoSerializer().newInstance()
+ val serStream = ser.serializeStream(bufferStream)
+ var i = 0
+ while (bufferStream.position < blockSize) {
+ serStream.writeObject(lines(i))
+ i = (i + 1) % lines.length
+ }
+ bufferStream.trim()
+ val array = bufferStream.array
+
+ val countBuf = ByteBuffer.wrap(new Array[Byte](4))
+ countBuf.putInt(array.length)
+ countBuf.flip()
+
+ val serverSocket = new ServerSocket(port)
+ logInfo("Listening on port " + port)
+
+ while (true) {
+ val socket = serverSocket.accept()
+ logInfo("Got a new connection")
+ val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec)
+ try {
+ while (true) {
+ out.write(countBuf.array)
+ out.write(array)
+ }
+ } catch {
+ case e: IOException =>
+ logError("Client disconnected")
+ socket.close()
+ }
+ }
+ }
+}
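
A minimal usage sketch, assuming the object above: the four arguments are the listen port, an input text file, the target block size in bytes and the send rate in bytes per second. The demo object, file name and numbers are illustrative.

    // Sketch only: serve ~256 KB blocks of Kryo-serialized lines at 10 MB/s on port 9999.
    object RawTextSenderDemo {
      def main(args: Array[String]) {
        spark.streaming.util.RawTextSender.main(
          Array("9999", "sentences.txt", (256 * 1024).toString, (10 * 1024 * 1024).toString))
      }
    }
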
diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
new file mode 100644
index 0000000000..dc55fd902b
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
@@ -0,0 +1,73 @@
+package spark.streaming.util
+
+class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
+
+ val minPollTime = 25L
+
+ val pollTime = {
+ if (period / 10.0 > minPollTime) {
+ (period / 10.0).toLong
+ } else {
+ minPollTime
+ }
+ }
+
+ val thread = new Thread() {
+ override def run() { loop }
+ }
+
+ var nextTime = 0L
+
+ def start(startTime: Long): Long = {
+ nextTime = startTime
+ thread.start()
+ nextTime
+ }
+
+ def start(): Long = {
+ val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
+ start(startTime)
+ }
+
+ def restart(originalStartTime: Long): Long = {
+ val gap = clock.currentTime - originalStartTime
+ val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
+ start(newStartTime)
+ }
+
+ def stop() {
+ thread.interrupt()
+ }
+
+ def loop() {
+ try {
+ while (true) {
+ clock.waitTillTime(nextTime)
+ callback(nextTime)
+ nextTime += period
+ }
+
+ } catch {
+ case e: InterruptedException =>
+ }
+ }
+}
+
+object RecurringTimer {
+
+ def main(args: Array[String]) {
+ var lastRecurTime = 0L
+ val period = 1000
+
+ def onRecur(time: Long) {
+ val currentTime = System.currentTimeMillis()
+ println("" + currentTime + ": " + (currentTime - lastRecurTime))
+ lastRecurTime = currentTime
+ }
+ val timer = new RecurringTimer(new SystemClock(), period, onRecur)
+ timer.start()
+ Thread.sleep(30 * 1000)
+ timer.stop()
+ }
+}
+
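
RecurringTimer depends only on the Clock trait, so pairing it with ManualClock makes the callbacks fire deterministically under test. A minimal sketch, assuming the two classes above; the demo object, period and times are illustrative.

    import spark.streaming.util.{ManualClock, RecurringTimer}

    // Sketch only: advancing the manual clock by 3000 ms releases the ticks
    // scheduled at 1000, 2000 and 3000.
    object RecurringTimerDemo {
      def main(args: Array[String]) {
        val clock = new ManualClock()
        val timer = new RecurringTimer(clock, 1000, t => println("Tick at " + t))
        timer.start(1000)
        clock.addToTime(3000)
        Thread.sleep(100)   // give the timer thread a moment to drain the released ticks
        timer.stop()
      }
    }
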
diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala
new file mode 100644
index 0000000000..3922dfbad6
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala
@@ -0,0 +1,67 @@
+package spark.streaming.util
+
+import java.net.{Socket, ServerSocket}
+import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream}
+
+object Receiver {
+ def main(args: Array[String]) {
+ val port = args(0).toInt
+ val lsocket = new ServerSocket(port)
+ println("Listening on port " + port )
+ while(true) {
+ val socket = lsocket.accept()
+ (new Thread() {
+ override def run() {
+ val buffer = new Array[Byte](100000)
+ var count = 0
+ val time = System.currentTimeMillis
+ try {
+ val is = new DataInputStream(new BufferedInputStream(socket.getInputStream))
+ var loop = true
+ var string: String = null
+ do {
+ string = is.readUTF()
+ if (string != null) {
+              count += 28  // 26-char test string + 2-byte length prefix written by writeUTF
+ }
+ } while (string != null)
+ } catch {
+ case e: Exception => e.printStackTrace()
+ }
+ val timeTaken = System.currentTimeMillis - time
+ val tput = (count / 1024.0) / (timeTaken / 1000.0)
+ println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s")
+ }
+ }).start()
+ }
+ }
+
+}
+
+object Sender {
+
+ def main(args: Array[String]) {
+ try {
+ val host = args(0)
+ val port = args(1).toInt
+ val size = args(2).toInt
+
+ val byteStream = new ByteArrayOutputStream()
+ val stringDataStream = new DataOutputStream(byteStream)
+ (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy"))
+ val bytes = byteStream.toByteArray()
+ println("Generated array of " + bytes.length + " bytes")
+
+ /*val bytes = new Array[Byte](size)*/
+ val socket = new Socket(host, port)
+ val os = socket.getOutputStream
+ os.write(bytes)
+ os.flush
+ socket.close()
+
+ } catch {
+ case e: Exception => e.printStackTrace
+ }
+ }
+}
+
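
The two objects above are meant to be run against each other. A minimal single-process sketch, assuming both compile as shown; the demo object, port and message count are illustrative, and the receiver's accept loop never terminates on its own.

    import spark.streaming.util.{Receiver, Sender}

    // Sketch only: start the receiver in the background, then push 100000 short
    // UTF strings at it so it prints the measured throughput.
    object SenderReceiverDemo {
      def main(args: Array[String]) {
        new Thread() { override def run() { Receiver.main(Array("9999")) } }.start()
        Thread.sleep(1000)   // crude wait for the server socket to come up
        Sender.main(Array("localhost", "9999", "100000"))
      }
    }
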
diff --git a/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala
new file mode 100644
index 0000000000..94e8f7a849
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala
@@ -0,0 +1,92 @@
+package spark.streaming.util
+
+import spark._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+import scala.io.Source
+
+import java.net.InetSocketAddress
+
+import org.apache.hadoop.fs._
+import org.apache.hadoop.conf._
+import org.apache.hadoop.io._
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.util._
+
+object SentenceFileGenerator {
+
+ def printUsage () {
+ println ("Usage: SentenceFileGenerator <master> <target directory> <# partitions> <sentence file> [<sentences per second>]")
+ System.exit(0)
+ }
+
+ def main (args: Array[String]) {
+ if (args.length < 4) {
+ printUsage
+ }
+
+ val master = args(0)
+ val fs = new Path(args(1)).getFileSystem(new Configuration())
+ val targetDirectory = new Path(args(1)).makeQualified(fs)
+ val numPartitions = args(2).toInt
+ val sentenceFile = args(3)
+ val sentencesPerSecond = {
+ if (args.length > 4) args(4).toInt
+ else 10
+ }
+
+ val source = Source.fromFile(sentenceFile)
+ val lines = source.mkString.split ("\n").toArray
+ source.close ()
+ println("Read " + lines.length + " lines from file " + sentenceFile)
+
+ val sentences = {
+ val buffer = ArrayBuffer[String]()
+ val random = new Random()
+ var i = 0
+ while (i < sentencesPerSecond) {
+ buffer += lines(random.nextInt(lines.length))
+ i += 1
+ }
+ buffer.toArray
+ }
+ println("Generated " + sentences.length + " sentences")
+
+ val sc = new SparkContext(master, "SentenceFileGenerator")
+ val sentencesRDD = sc.parallelize(sentences, numPartitions)
+
+ val tempDirectory = new Path(targetDirectory, "_tmp")
+
+ fs.mkdirs(targetDirectory)
+ fs.mkdirs(tempDirectory)
+
+ var saveTimeMillis = System.currentTimeMillis
+ try {
+ while (true) {
+ val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis)
+ val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis)
+ println("Writing to file " + newDir)
+ sentencesRDD.saveAsTextFile(tmpNewDir.toString)
+ fs.rename(tmpNewDir, newDir)
+ saveTimeMillis += 1000
+ val sleepTimeMillis = {
+ val currentTimeMillis = System.currentTimeMillis
+ if (saveTimeMillis < currentTimeMillis) {
+ 0
+ } else {
+ saveTimeMillis - currentTimeMillis
+ }
+ }
+ println("Sleeping for " + sleepTimeMillis + " ms")
+ Thread.sleep(sleepTimeMillis)
+ }
+ } catch {
+ case e: Exception =>
+ }
+ }
+}
+
+
+
+
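
A minimal usage sketch, assuming the object above: it writes a fresh "Sentences-<timestamp>" directory roughly once per second under the target directory, which a file-based input stream can then pick up. The demo object, master, paths and rate are illustrative.

    // Sketch only: arguments are master, target directory, number of partitions,
    // sentence file and (optionally) sentences per second.
    object SentenceFileGeneratorDemo {
      def main(args: Array[String]) {
        spark.streaming.util.SentenceFileGenerator.main(
          Array("local[2]", "/tmp/sentence-stream", "4", "sentences.txt", "100"))
      }
    }
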
diff --git a/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala
new file mode 100644
index 0000000000..60085f4f88
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala
@@ -0,0 +1,23 @@
+package spark.streaming.util
+
+import spark.SparkContext
+import SparkContext._
+
+object ShuffleTest {
+ def main(args: Array[String]) {
+
+ if (args.length < 1) {
+ println ("Usage: ShuffleTest <host>")
+ System.exit(1)
+ }
+
+ val sc = new spark.SparkContext(args(0), "ShuffleTest")
+ val rdd = sc.parallelize(1 to 1000, 500).cache
+
+ def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) }
+
+ time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } }
+ System.exit(0)
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala
new file mode 100644
index 0000000000..23e9235c60
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala
@@ -0,0 +1,107 @@
+package spark.streaming.util
+
+import scala.util.Random
+import scala.io.Source
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.remote._
+import scala.actors.remote.RemoteActor._
+
+import java.net.InetSocketAddress
+
+
+object TestGenerator {
+
+ def printUsage {
+ println ("Usage: SentenceGenerator <target IP> <target port> <sentence file> [<sentences per second>]")
+ System.exit(0)
+ }
+ /*
+ def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) {
+ val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1
+ val random = new Random ()
+
+ try {
+ var lastPrintTime = System.currentTimeMillis()
+ var count = 0
+ while(true) {
+ streamReceiver ! lines(random.nextInt(lines.length))
+ count += 1
+ if (System.currentTimeMillis - lastPrintTime >= 1000) {
+ println (count + " sentences sent last second")
+ count = 0
+ lastPrintTime = System.currentTimeMillis
+ }
+ Thread.sleep(sleepBetweenSentences.toLong)
+ }
+ } catch {
+ case e: Exception =>
+ }
+ }*/
+
+ def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) {
+ try {
+ val numSentences = if (sentencesPerSecond <= 0) {
+ lines.length
+ } else {
+ sentencesPerSecond
+ }
+ val sentences = lines.take(numSentences).toArray
+
+ var nextSendingTime = System.currentTimeMillis()
+ val sendAsArray = true
+ while(true) {
+ if (sendAsArray) {
+ println("Sending as array")
+ streamReceiver !? sentences
+ } else {
+ println("Sending individually")
+ sentences.foreach(sentence => {
+ streamReceiver !? sentence
+ })
+ }
+ println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms")
+ nextSendingTime += 1000
+ val sleepTime = nextSendingTime - System.currentTimeMillis
+ if (sleepTime > 0) {
+ println ("Sleeping for " + sleepTime + " ms")
+ Thread.sleep(sleepTime)
+ }
+ }
+ } catch {
+ case e: Exception =>
+ }
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ printUsage
+ }
+
+ val generateRandomly = false
+
+ val streamReceiverIP = args(0)
+ val streamReceiverPort = args(1).toInt
+ val sentenceFile = args(2)
+ val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10
+ val sentenceInputName = if (args.length > 4) args(4) else "Sentences"
+
+ println("Sending " + sentencesPerSecond + " sentences per second to " +
+ streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName)
+ val source = Source.fromFile(sentenceFile)
+ val lines = source.mkString.split ("\n")
+ source.close ()
+
+ val streamReceiver = select(
+ Node(streamReceiverIP, streamReceiverPort),
+ Symbol("NetworkStreamReceiver-" + sentenceInputName))
+ if (generateRandomly) {
+ /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/
+ } else {
+ generateSameSentences(lines, sentencesPerSecond, streamReceiver)
+ }
+ }
+}
+
+
+
diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala
new file mode 100644
index 0000000000..ff840d084f
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala
@@ -0,0 +1,119 @@
+package spark.streaming.util
+
+import scala.util.Random
+import scala.io.Source
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.remote._
+import scala.actors.remote.RemoteActor._
+
+import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream}
+import java.net.Socket
+
+object TestGenerator2 {
+
+ def printUsage {
+ println ("Usage: SentenceGenerator <target IP> <target port> <sentence file> [<sentences per second>]")
+ System.exit(0)
+ }
+
+ def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){
+ try {
+ println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort)
+ val socket = new Socket(streamReceiverHost, streamReceiverPort)
+
+ println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort )
+ val currentTime = System.currentTimeMillis
+ var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime
+ Thread.sleep(targetTime - currentTime)
+
+ while(true) {
+ val startTime = System.currentTimeMillis()
+ println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms")
+ val socketOutputStream = socket.getOutputStream
+ val parts = 10
+ (0 until parts).foreach(i => {
+ val partStartTime = System.currentTimeMillis
+
+ val offset = (i * bytes.length / parts).toInt
+ val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length)
+ socketOutputStream.write(bytes, offset, len)
+ socketOutputStream.flush()
+ val partFinishTime = System.currentTimeMillis
+ println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms")
+ val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1)
+ Thread.sleep(sleepTime)
+ })
+
+ socketOutputStream.flush()
+ /*val socketInputStream = new DataInputStream(socket.getInputStream)*/
+ /*val reply = socketInputStream.readUTF()*/
+ val finishTime = System.currentTimeMillis()
+ println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]")
+ /*println("Received = " + reply)*/
+ targetTime = targetTime + intervalTime
+ val sleepTime = (targetTime - finishTime) + 10
+ if (sleepTime > 0) {
+ println("Sleeping for " + sleepTime + " ms")
+ Thread.sleep(sleepTime)
+ } else {
+ println("############################")
+ println("###### Skipping sleep ######")
+ println("############################")
+ }
+ }
+ } catch {
+ case e: Exception => println(e)
+ }
+ println("Stopped sending")
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 4) {
+ printUsage
+ }
+
+ val streamReceiverHost = args(0)
+ val streamReceiverPort = args(1).toInt
+ val sentenceFile = args(2)
+ val intervalTime = args(3).toLong
+ val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0
+
+ println("Reading the file " + sentenceFile)
+ val source = Source.fromFile(sentenceFile)
+ val lines = source.mkString.split ("\n")
+ source.close()
+
+ val numSentences = if (sentencesPerInterval <= 0) {
+ lines.length
+ } else {
+ sentencesPerInterval
+ }
+
+ println("Generating sentences")
+ val sentences: Array[String] = if (numSentences <= lines.length) {
+ lines.take(numSentences).toArray
+ } else {
+ (0 until numSentences).map(i => lines(i % lines.length)).toArray
+ }
+
+ println("Converting to byte array")
+ val byteStream = new ByteArrayOutputStream()
+ val stringDataStream = new DataOutputStream(byteStream)
+ /*stringDataStream.writeInt(sentences.size)*/
+ sentences.foreach(stringDataStream.writeUTF)
+ val bytes = byteStream.toByteArray()
+ stringDataStream.close()
+ println("Generated array of " + bytes.length + " bytes")
+
+ /*while(true) { */
+ sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime)
+ /*println("Sleeping for 5 seconds")*/
+ /*Thread.sleep(5000)*/
+ /*System.gc()*/
+ /*}*/
+ }
+}
+
+
+
diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala
new file mode 100644
index 0000000000..9c39ef3e12
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala
@@ -0,0 +1,244 @@
+package spark.streaming.util
+
+import spark.Logging
+
+import scala.util.Random
+import scala.io.Source
+import scala.collection.mutable.{ArrayBuffer, Queue}
+
+import java.net._
+import java.io._
+import java.nio._
+import java.nio.charset._
+import java.nio.channels._
+
+import it.unimi.dsi.fastutil.io._
+
+class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int)
+extends Logging {
+
+ class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4)
+ extends ConnectionHandler(host, port, true) {
+
+ val buffers = new ArrayBuffer[ByteBuffer]
+ val newBuffers = new Queue[ByteBuffer]
+ var activeKey: SelectionKey = null
+
+ def send(buffer: ByteBuffer) {
+ logDebug("Sending: " + buffer)
+ newBuffers.synchronized {
+ newBuffers.enqueue(buffer)
+ }
+ selector.wakeup()
+ buffer.synchronized {
+ buffer.wait()
+ }
+ }
+
+ override def ready(key: SelectionKey) {
+ logDebug("Ready")
+ activeKey = key
+ val channel = key.channel.asInstanceOf[SocketChannel]
+ channel.register(selector, SelectionKey.OP_WRITE)
+ generator.startSending()
+ }
+
+ override def preSelect() {
+ newBuffers.synchronized {
+ while(!newBuffers.isEmpty) {
+ val buffer = newBuffers.dequeue
+ buffers += buffer
+ logDebug("Added: " + buffer)
+ changeInterest(activeKey, SelectionKey.OP_WRITE)
+ }
+ }
+ }
+
+ override def write(key: SelectionKey) {
+ try {
+ /*while(true) {*/
+ val channel = key.channel.asInstanceOf[SocketChannel]
+ if (buffers.size > 0) {
+ val buffer = buffers(0)
+ val newBuffer = buffer.slice()
+ newBuffer.limit(math.min(newBuffer.remaining, 32768))
+ val bytesWritten = channel.write(newBuffer)
+ buffer.position(buffer.position + bytesWritten)
+ if (bytesWritten == 0) return
+ if (buffer.remaining == 0) {
+ buffers -= buffer
+ buffer.synchronized {
+ buffer.notify()
+ }
+ }
+ /*changeInterest(key, SelectionKey.OP_WRITE)*/
+ } else {
+ changeInterest(key, 0)
+ }
+ /*}*/
+ } catch {
+ case e: IOException => {
+ if (e.toString.contains("pipe") || e.toString.contains("reset")) {
+ logError("Connection broken")
+ } else {
+ logError("Connection error", e)
+ }
+ close(key)
+ }
+ }
+ }
+
+ override def close(key: SelectionKey) {
+ buffers.clear()
+ super.close(key)
+ }
+ }
+
+ initLogging()
+
+ val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this)
+ var sendingThread: Thread = null
+ var sendCount = 0
+ val sendBatches = 5
+
+ def run() {
+ logInfo("Connection handler started")
+ connectionHandler.start()
+ connectionHandler.join()
+ if (sendingThread != null && !sendingThread.isInterrupted) {
+ sendingThread.interrupt
+ }
+ logInfo("Connection handler stopped")
+ }
+
+ def startSending() {
+ sendingThread = new Thread() {
+ override def run() {
+ logInfo("STARTING TO SEND")
+ sendSentences()
+ logInfo("SENDING STOPPED AFTER " + sendCount)
+ connectionHandler.interrupt()
+ }
+ }
+ sendingThread.start()
+ }
+
+ def stopSending() {
+ sendingThread.interrupt()
+ }
+
+ def sendSentences() {
+ logInfo("Reading the file " + sentenceFile)
+ val source = Source.fromFile(sentenceFile)
+ val lines = source.mkString.split ("\n")
+ source.close()
+
+ val numSentences = if (sentencesPerInterval <= 0) {
+ lines.length
+ } else {
+ sentencesPerInterval
+ }
+
+ logInfo("Generating sentence buffer")
+ val sentences: Array[String] = if (numSentences <= lines.length) {
+ lines.take(numSentences).toArray
+ } else {
+ (0 until numSentences).map(i => lines(i % lines.length)).toArray
+ }
+
+ /*
+ val sentences: Array[String] = if (numSentences <= lines.length) {
+ lines.take((numSentences / sendBatches).toInt).toArray
+ } else {
+ (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray
+ }*/
+
+
+ val serializer = new spark.KryoSerializer().newInstance()
+ val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024)
+ serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close()
+ byteStream.trim()
+ val sentenceBuffer = ByteBuffer.wrap(byteStream.array)
+
+ logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort )
+ val currentTime = System.currentTimeMillis
+ var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration
+ Thread.sleep(targetTime - currentTime)
+
+ val totalBytes = sentenceBuffer.limit
+
+ while(true) {
+ val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches
+
+ val startTime = System.currentTimeMillis()
+ logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms")
+
+ (0 until batchesInCurrentInterval).foreach(i => {
+ try {
+ val position = (i * totalBytes / sendBatches).toInt
+ val limit = if (i == sendBatches - 1) {
+ totalBytes
+ } else {
+ ((i + 1) * totalBytes / sendBatches).toInt - 1
+ }
+
+ val partStartTime = System.currentTimeMillis
+ sentenceBuffer.limit(limit)
+ connectionHandler.send(sentenceBuffer)
+ val partFinishTime = System.currentTimeMillis
+ val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1)
+ Thread.sleep(sleepTime)
+
+ } catch {
+ case ie: InterruptedException => return
+ case e: Exception => e.printStackTrace()
+ }
+ })
+ sentenceBuffer.rewind()
+
+ val finishTime = System.currentTimeMillis()
+ /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/
+ targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0)
+
+ val sleepTime = (targetTime - finishTime) + 20
+ if (sleepTime > 0) {
+ logInfo("Sleeping for " + sleepTime + " ms")
+ Thread.sleep(sleepTime)
+ } else {
+ logInfo("###### Skipping sleep ######")
+ }
+ if (Thread.currentThread.isInterrupted) {
+ return
+ }
+ sendCount += 1
+ }
+ }
+}
+
+object TestGenerator4 {
+ def printUsage {
+ println("Usage: TestGenerator4 <target IP> <target port> <sentence file> <interval duration> [<sentences per second>]")
+ System.exit(0)
+ }
+
+ def main(args: Array[String]) {
+ println("GENERATOR STARTED")
+ if (args.length < 4) {
+ printUsage
+ }
+
+
+ val streamReceiverHost = args(0)
+ val streamReceiverPort = args(1).toInt
+ val sentenceFile = args(2)
+ val intervalDuration = args(3).toLong
+ val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0
+
+ while(true) {
+ val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval)
+ generator.run()
+ Thread.sleep(2000)
+ }
+ println("GENERATOR STOPPED")
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala
new file mode 100644
index 0000000000..f584f772bb
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala
@@ -0,0 +1,39 @@
+package spark.streaming.util
+
+import spark.streaming._
+import spark.Logging
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+
+sealed trait TestStreamCoordinatorMessage
+case class GetStreamDetails extends TestStreamCoordinatorMessage
+case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage
+case class TestStarted extends TestStreamCoordinatorMessage
+
+class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging {
+
+ var index = 0
+
+ initLogging()
+
+ logInfo("Created")
+
+ def receive = {
+ case TestStarted => {
+ sender ! "OK"
+ }
+
+ case GetStreamDetails => {
+ val streamDetail = if (index >= streamDetails.length) null else streamDetails(index)
+ sender ! GotStreamDetails(streamDetail._1, streamDetail._2)
+ index += 1
+ if (streamDetail != null) {
+ logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")" )
+ }
+ }
+ }
+
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala
new file mode 100644
index 0000000000..80ad924dd8
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala
@@ -0,0 +1,421 @@
+package spark.streaming.util
+
+import spark._
+import spark.storage._
+import spark.util.AkkaUtils
+import spark.streaming._
+
+import scala.math._
+import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap}
+
+import akka.actor._
+import akka.actor.Actor
+import akka.dispatch._
+import akka.pattern.ask
+import akka.util.duration._
+
+import java.io.DataInputStream
+import java.io.BufferedInputStream
+import java.net.Socket
+import java.net.ServerSocket
+import java.util.LinkedHashMap
+
+import org.apache.hadoop.fs._
+import org.apache.hadoop.conf._
+import org.apache.hadoop.io._
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.util._
+
+import spark.Utils
+
+
+class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager)
+extends Thread with Logging {
+
+
+ class DataHandler(
+ inputName: String,
+ longIntervalDuration: Time,
+ shortIntervalDuration: Time,
+ blockManager: BlockManager
+ )
+ extends Logging {
+
+ class Block(var id: String, var shortInterval: Interval) {
+ val data = ArrayBuffer[String]()
+ var pushed = false
+ def longInterval = getLongInterval(shortInterval)
+ def empty() = (data.size == 0)
+ def += (str: String) = (data += str)
+ override def toString() = "Block " + id
+ }
+
+ class Bucket(val longInterval: Interval) {
+ val blocks = new ArrayBuffer[Block]()
+ var filled = false
+ def += (block: Block) = blocks += block
+ def empty() = (blocks.size == 0)
+ def ready() = (filled && !blocks.exists(! _.pushed))
+ def blockIds() = blocks.map(_.id).toArray
+ override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]"
+ }
+
+ initLogging()
+
+ val shortIntervalDurationMillis = shortIntervalDuration.toLong
+ val longIntervalDurationMillis = longIntervalDuration.toLong
+
+ var currentBlock: Block = null
+ var currentBucket: Bucket = null
+
+ val blocksForPushing = new Queue[Block]()
+ val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket]
+
+ val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } }
+ val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+ def start() {
+ blockUpdatingThread.start()
+ blockPushingThread.start()
+ }
+
+ def += (data: String) = addData(data)
+
+ def addData(data: String) {
+ if (currentBlock == null) {
+ updateCurrentBlock()
+ }
+ currentBlock.synchronized {
+ currentBlock += data
+ }
+ }
+
+ def getShortInterval(time: Time): Interval = {
+ val intervalBegin = time.floor(shortIntervalDuration)
+ Interval(intervalBegin, intervalBegin + shortIntervalDuration)
+ }
+
+ def getLongInterval(shortInterval: Interval): Interval = {
+ val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration)
+ Interval(intervalBegin, intervalBegin + longIntervalDuration)
+ }
+
+ def updateCurrentBlock() {
+ /*logInfo("Updating current block")*/
+ val currentTime = Time(System.currentTimeMillis)
+ val shortInterval = getShortInterval(currentTime)
+ val longInterval = getLongInterval(shortInterval)
+
+ def createBlock(reuseCurrentBlock: Boolean = false) {
+ val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size
+ if (!reuseCurrentBlock) {
+ val newBlock = new Block(newBlockId, shortInterval)
+ /*logInfo("Created " + currentBlock)*/
+ currentBlock = newBlock
+ } else {
+ currentBlock.shortInterval = shortInterval
+ currentBlock.id = newBlockId
+ }
+ }
+
+ def createBucket() {
+ val newBucket = new Bucket(longInterval)
+ buckets += ((longInterval, newBucket))
+ currentBucket = newBucket
+ /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/
+ }
+
+ if (currentBlock == null || currentBucket == null) {
+ createBucket()
+ currentBucket.synchronized {
+ createBlock()
+ }
+ return
+ }
+
+ currentBlock.synchronized {
+ var reuseCurrentBlock = false
+
+ if (shortInterval != currentBlock.shortInterval) {
+ if (!currentBlock.empty) {
+ blocksForPushing.synchronized {
+ blocksForPushing += currentBlock
+ blocksForPushing.notifyAll()
+ }
+ }
+
+ currentBucket.synchronized {
+ if (currentBlock.empty) {
+ reuseCurrentBlock = true
+ } else {
+ currentBucket += currentBlock
+ }
+
+ if (longInterval != currentBucket.longInterval) {
+ currentBucket.filled = true
+ if (currentBucket.ready) {
+ currentBucket.notifyAll()
+ }
+ createBucket()
+ }
+ }
+
+ createBlock(reuseCurrentBlock)
+ }
+ }
+ }
+
+ def pushBlock(block: Block) {
+ try{
+ if (blockManager != null) {
+ logInfo("Pushing block")
+ val startTime = System.currentTimeMillis
+
+ val bytes = blockManager.dataSerialize("rdd_", block.data.toIterator) // TODO: Will this be an RDD block?
+ val finishTime = System.currentTimeMillis
+ logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s")
+
+ blockManager.putBytes(block.id.toString, bytes, StorageLevel.MEMORY_AND_DISK_SER_2)
+ /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/
+ /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/
+ /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/
+ val finishTime1 = System.currentTimeMillis
+ logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s")
+ } else {
+ logWarning(block + " not put as block manager is null")
+ }
+ } catch {
+ case e: Exception => logError("Exception writing " + block + " to blockmanager" , e)
+ }
+ }
+
+ def getBucket(longInterval: Interval): Option[Bucket] = {
+ buckets.get(longInterval)
+ }
+
+ def clearBucket(longInterval: Interval) {
+ buckets.remove(longInterval)
+ }
+
+ def keepUpdatingCurrentBlock() {
+ logInfo("Thread to update current block started")
+ while(true) {
+ updateCurrentBlock()
+ val currentTimeMillis = System.currentTimeMillis
+ val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) *
+ shortIntervalDurationMillis - currentTimeMillis + 1
+ Thread.sleep(sleepTimeMillis)
+ }
+ }
+
+ def keepPushingBlocks() {
+ var loop = true
+ logInfo("Thread to push blocks started")
+ while(loop) {
+ val block = blocksForPushing.synchronized {
+ if (blocksForPushing.size == 0) {
+ blocksForPushing.wait()
+ }
+ blocksForPushing.dequeue
+ }
+ pushBlock(block)
+ block.pushed = true
+ block.data.clear()
+
+ val bucket = buckets(block.longInterval)
+ bucket.synchronized {
+ if (bucket.ready) {
+ bucket.notifyAll()
+ }
+ }
+ }
+ }
+ }
+
+
+ class ConnectionListener(port: Int, dataHandler: DataHandler)
+ extends Thread with Logging {
+ initLogging()
+ override def run {
+ try {
+ val listener = new ServerSocket(port)
+ logInfo("Listening on port " + port)
+ while (true) {
+          new ConnectionHandler(listener.accept(), dataHandler).start()
+ }
+ listener.close()
+ } catch {
+ case e: Exception => logError("", e);
+ }
+ }
+ }
+
+ class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging {
+ initLogging()
+ override def run {
+ logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort)
+ val bytes = new Array[Byte](100 * 1024 * 1024)
+ try {
+
+ val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024))
+ /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/
+ var str: String = null
+ str = inputStream.readUTF
+ while(str != null) {
+ dataHandler += str
+ str = inputStream.readUTF()
+ }
+
+ /*
+ var loop = true
+ while(loop) {
+ val numRead = inputStream.read(bytes)
+ if (numRead < 0) {
+ loop = false
+ }
+ inbox += ((LongTime(SystemTime.currentTimeMillis), "test"))
+ }*/
+
+ inputStream.close()
+ } catch {
+ case e => logError("Error receiving data", e)
+ }
+ socket.close()
+ }
+ }
+
+ initLogging()
+
+ val masterHost = System.getProperty("spark.master.host")
+ val masterPort = System.getProperty("spark.master.port").toInt
+
+ val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort)
+ val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler")
+ val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator")
+
+ logInfo("Getting stream details from master " + masterHost + ":" + masterPort)
+
+ val timeout = 50 millis
+
+ var started = false
+ while (!started) {
+ askActor[String](testStreamCoordinator, TestStarted) match {
+ case Some(str) => {
+ started = true
+ logInfo("TestStreamCoordinator started")
+ }
+ case None => {
+ logInfo("TestStreamCoordinator not started yet")
+ Thread.sleep(200)
+ }
+ }
+ }
+
+ val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match {
+ case Some(details) => details
+ case None => throw new Exception("Could not get stream details")
+ }
+ logInfo("Stream details received: " + streamDetails)
+
+ val inputName = streamDetails.name
+ val intervalDurationMillis = streamDetails.duration
+ val intervalDuration = Time(intervalDurationMillis)
+
+ val dataHandler = new DataHandler(
+ inputName,
+ intervalDuration,
+ Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS),
+ blockManager)
+
+ val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler)
+
+ // Send a message to an actor and return an option with its reply, or None if this times out
+ def askActor[T](actor: ActorRef, message: Any): Option[T] = {
+ try {
+ val future = actor.ask(message)(timeout)
+ return Some(Await.result(future, timeout).asInstanceOf[T])
+ } catch {
+ case e: Exception =>
+ logInfo("Error communicating with " + actor, e)
+ return None
+ }
+ }
+
+ override def run() {
+ connListener.start()
+ dataHandler.start()
+
+ var interval = Interval.currentInterval(intervalDuration)
+ var dataStarted = false
+
+ while(true) {
+ waitFor(interval.endTime)
+ logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)
+ dataHandler.getBucket(interval) match {
+ case Some(bucket) => {
+ logInfo("Found " + bucket + " for " + interval)
+ bucket.synchronized {
+ if (!bucket.ready) {
+ logInfo("Waiting for " + bucket)
+ bucket.wait()
+ logInfo("Wait over for " + bucket)
+ }
+ if (dataStarted || !bucket.empty) {
+ logInfo("Notifying " + bucket)
+ notifyScheduler(interval, bucket.blockIds)
+ dataStarted = true
+ }
+ bucket.blocks.clear()
+ dataHandler.clearBucket(interval)
+ }
+ }
+ case None => {
+ logInfo("Found none for " + interval)
+ if (dataStarted) {
+ logInfo("Notifying none")
+ notifyScheduler(interval, Array[String]())
+ }
+ }
+ }
+ interval = interval.next
+ }
+ }
+
+ def waitFor(time: Time) {
+ val currentTimeMillis = System.currentTimeMillis
+ val targetTimeMillis = time.milliseconds
+ if (currentTimeMillis < targetTimeMillis) {
+ val sleepTime = (targetTimeMillis - currentTimeMillis)
+ Thread.sleep(sleepTime + 1)
+ }
+ }
+
+ def notifyScheduler(interval: Interval, blockIds: Array[String]) {
+ try {
+ sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray)
+ val time = interval.endTime
+ val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0
+ logInfo("Pushing delay for " + time + " is " + delay + " s")
+ } catch {
+ case _ => logError("Exception notifying scheduler at interval " + interval)
+ }
+ }
+}
+
+object TestStreamReceiver3 {
+
+ val PORT = 9999
+ val SHORT_INTERVAL_MILLIS = 100
+
+ def main(args: Array[String]) {
+ System.setProperty("spark.master.host", Utils.localHostName)
+ System.setProperty("spark.master.port", "7078")
+ val details = Array(("Sentences", 2000L))
+ val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078)
+ actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator")
+ new TestStreamReceiver3(actorSystem, null).start()
+ }
+}
+
+
+
diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala
new file mode 100644
index 0000000000..31754870dd
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala
@@ -0,0 +1,374 @@
+package spark.streaming.util
+
+import spark.streaming._
+import spark._
+import spark.storage._
+import spark.util.AkkaUtils
+
+import scala.math._
+import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap}
+
+import java.io._
+import java.nio._
+import java.nio.charset._
+import java.nio.channels._
+import java.util.concurrent.Executors
+
+import akka.actor._
+import akka.actor.Actor
+import akka.dispatch._
+import akka.pattern.ask
+import akka.util.duration._
+
+class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager)
+extends Thread with Logging {
+
+ class DataHandler(
+ inputName: String,
+ longIntervalDuration: Time,
+ shortIntervalDuration: Time,
+ blockManager: BlockManager
+ )
+ extends Logging {
+
+ class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) {
+ var pushed = false
+ def longInterval = getLongInterval(shortInterval)
+ override def toString() = "Block " + id
+ }
+
+ class Bucket(val longInterval: Interval) {
+ val blocks = new ArrayBuffer[Block]()
+ var filled = false
+ def += (block: Block) = blocks += block
+ def empty() = (blocks.size == 0)
+ def ready() = (filled && !blocks.exists(! _.pushed))
+ def blockIds() = blocks.map(_.id).toArray
+ override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]"
+ }
+
+ initLogging()
+
+ val syncOnLastShortInterval = true
+
+ val shortIntervalDurationMillis = shortIntervalDuration.milliseconds
+ val longIntervalDurationMillis = longIntervalDuration.milliseconds
+
+ val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024)
+ var currentShortInterval = Interval.currentInterval(shortIntervalDuration)
+
+ val blocksForPushing = new Queue[Block]()
+ val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket]
+
+ val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } }
+ val blockPushingExecutor = Executors.newFixedThreadPool(5)
+
+
+ def start() {
+ buffer.clear()
+ if (buffer.remaining == 0) {
+ throw new Exception("Buffer initialization error")
+ }
+ bufferProcessingThread.start()
+ }
+
+ def readDataToBuffer(func: ByteBuffer => Int): Int = {
+ buffer.synchronized {
+ if (buffer.remaining == 0) {
+ logInfo("Received first data for interval " + currentShortInterval)
+ }
+ func(buffer)
+ }
+ }
+
+ def getLongInterval(shortInterval: Interval): Interval = {
+ val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration)
+ Interval(intervalBegin, intervalBegin + longIntervalDuration)
+ }
+
+ def processBuffer() {
+
+ def readInt(buffer: ByteBuffer): Int = {
+ var offset = 0
+ var result = 0
+ while (offset < 32) {
+ val b = buffer.get()
+ result |= ((b & 0x7F) << offset)
+ if ((b & 0x80) == 0) {
+ return result
+ }
+ offset += 7
+ }
+ throw new Exception("Malformed zigzag-encoded integer")
+ }
+
+ val currentLongInterval = getLongInterval(currentShortInterval)
+ val startTime = System.currentTimeMillis
+ val newBuffer: ByteBuffer = buffer.synchronized {
+ buffer.flip()
+ if (buffer.remaining == 0) {
+ buffer.clear()
+ null
+ } else {
+ logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms")
+ val startTime1 = System.currentTimeMillis
+ var loop = true
+ var count = 0
+ while(loop) {
+ buffer.mark()
+ try {
+ val len = readInt(buffer)
+ buffer.position(buffer.position + len)
+ count += 1
+ } catch {
+ case e: Exception => {
+ buffer.reset()
+ loop = false
+ }
+ }
+ }
+ val bytesToCopy = buffer.position
+ val newBuf = ByteBuffer.allocate(bytesToCopy)
+ buffer.position(0)
+ newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer])
+ newBuf.flip()
+ buffer.position(bytesToCopy)
+ buffer.compact()
+ newBuf
+ }
+ }
+
+ if (newBuffer != null) {
+ val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval))
+ bucket.synchronized {
+ val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString
+ val newBlock = new Block(newBlockId, currentShortInterval, newBuffer)
+ if (syncOnLastShortInterval) {
+ bucket += newBlock
+ }
+ logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" )
+ blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } })
+ }
+ }
+
+ val newShortInterval = Interval.currentInterval(shortIntervalDuration)
+ val newLongInterval = getLongInterval(newShortInterval)
+
+ if (newLongInterval != currentLongInterval) {
+ buckets.get(currentLongInterval) match {
+ case Some(bucket) => {
+ bucket.synchronized {
+ bucket.filled = true
+ if (bucket.ready) {
+ bucket.notifyAll()
+ }
+ }
+ }
+ case None =>
+ }
+ buckets += ((newLongInterval, new Bucket(newLongInterval)))
+ }
+
+ currentShortInterval = newShortInterval
+ }
+
+ def pushBlock(block: Block) {
+ try{
+ if (blockManager != null) {
+ val startTime = System.currentTimeMillis
+ logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms")
+ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/
+ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/
+ blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2)
+ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/
+ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/
+ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/
+ val finishTime = System.currentTimeMillis
+ logInfo(block + " put delay is " + (finishTime - startTime) + " ms")
+ } else {
+ logWarning(block + " not put as block manager is null")
+ }
+ } catch {
+ case e: Exception => logError("Exception writing " + block + " to blockmanager" , e)
+ }
+ }
+
+ def getBucket(longInterval: Interval): Option[Bucket] = {
+ buckets.get(longInterval)
+ }
+
+ def clearBucket(longInterval: Interval) {
+ buckets.remove(longInterval)
+ }
+
+ def keepProcessingBuffers() {
+ logInfo("Thread to process buffers started")
+ while(true) {
+ processBuffer()
+ val currentTimeMillis = System.currentTimeMillis
+ val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) *
+ shortIntervalDurationMillis - currentTimeMillis + 1
+ Thread.sleep(sleepTimeMillis)
+ }
+ }
+
+ def pushAndNotifyBlock(block: Block) {
+ pushBlock(block)
+ block.pushed = true
+ val bucket = if (syncOnLastShortInterval) {
+ buckets(block.longInterval)
+ } else {
+ var longInterval = block.longInterval
+ while(!buckets.contains(longInterval)) {
+ logWarning("Skipping bucket of " + longInterval + " for " + block)
+ longInterval = longInterval.next
+ }
+ val chosenBucket = buckets(longInterval)
+ logDebug("Choosing bucket of " + longInterval + " for " + block)
+ chosenBucket += block
+ chosenBucket
+ }
+
+ bucket.synchronized {
+ if (bucket.ready) {
+ bucket.notifyAll()
+ }
+ }
+
+ }
+ }
+
+
+ class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler)
+ extends ConnectionHandler(host, port, false) {
+
+ override def ready(key: SelectionKey) {
+ changeInterest(key, SelectionKey.OP_READ)
+ }
+
+ override def read(key: SelectionKey) {
+ try {
+ val channel = key.channel.asInstanceOf[SocketChannel]
+ val bytesRead = dataHandler.readDataToBuffer(channel.read)
+ if (bytesRead < 0) {
+ close(key)
+ }
+ } catch {
+ case e: IOException => {
+ logError("Error reading", e)
+ close(key)
+ }
+ }
+ }
+ }
+
+ initLogging()
+
+ val masterHost = System.getProperty("spark.master.host", "localhost")
+ val masterPort = System.getProperty("spark.master.port", "7078").toInt
+
+ val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort)
+ val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler")
+ val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator")
+
+ logInfo("Getting stream details from master " + masterHost + ":" + masterPort)
+
+ val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match {
+ case Some(details) => details
+ case None => throw new Exception("Could not get stream details")
+ }
+ logInfo("Stream details received: " + streamDetails)
+
+ val inputName = streamDetails.name
+ val intervalDurationMillis = streamDetails.duration
+ val intervalDuration = Milliseconds(intervalDurationMillis)
+ val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt)
+
+ val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager)
+ val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler)
+
+ val timeout = 100 millis
+
+ // Send a message to an actor and return an option with its reply, or None if this times out
+ def askActor[T](actor: ActorRef, message: Any): Option[T] = {
+ try {
+ val future = actor.ask(message)(timeout)
+ return Some(Await.result(future, timeout).asInstanceOf[T])
+ } catch {
+ case e: Exception =>
+ logInfo("Error communicating with " + actor, e)
+ return None
+ }
+ }
+
+ override def run() {
+ connectionHandler.start()
+ dataHandler.start()
+
+ var interval = Interval.currentInterval(intervalDuration)
+ var dataStarted = false
+
+
+ while(true) {
+ waitFor(interval.endTime)
+ /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/
+ dataHandler.getBucket(interval) match {
+ case Some(bucket) => {
+ logDebug("Found " + bucket + " for " + interval)
+ bucket.synchronized {
+ if (!bucket.ready) {
+ logDebug("Waiting for " + bucket)
+ bucket.wait()
+ logDebug("Wait over for " + bucket)
+ }
+ if (dataStarted || !bucket.empty) {
+ logDebug("Notifying " + bucket)
+ notifyScheduler(interval, bucket.blockIds)
+ dataStarted = true
+ }
+ bucket.blocks.clear()
+ dataHandler.clearBucket(interval)
+ }
+ }
+ case None => {
+ logDebug("Found none for " + interval)
+ if (dataStarted) {
+ logDebug("Notifying none")
+ notifyScheduler(interval, Array[String]())
+ }
+ }
+ }
+ interval = interval.next
+ }
+ }
+
+ def waitFor(time: Time) {
+ val currentTimeMillis = System.currentTimeMillis
+ val targetTimeMillis = time.milliseconds
+ if (currentTimeMillis < targetTimeMillis) {
+ val sleepTime = (targetTimeMillis - currentTimeMillis)
+ Thread.sleep(sleepTime + 1)
+ }
+ }
+
+ def notifyScheduler(interval: Interval, blockIds: Array[String]) {
+ try {
+ sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray)
+ val time = interval.endTime
+ val delay = (System.currentTimeMillis - time.milliseconds)
+ logInfo("Notification delay for " + time + " is " + delay + " ms")
+ } catch {
+ case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e)
+ }
+ }
+}
+
+
+object TestStreamReceiver4 {
+ def main(args: Array[String]) {
+ val details = Array(("Sentences", 2000L))
+ val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078)
+ actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator")
+ new TestStreamReceiver4(actorSystem, null).start()
+ }
+}
diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..02fe16866e
--- /dev/null
+++ b/streaming/src/test/resources/log4j.properties
@@ -0,0 +1,8 @@
+# Set everything to be logged to the console
+log4j.rootCategory=WARN, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
new file mode 100644
index 0000000000..d0aaac0f2e
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -0,0 +1,213 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+import scala.runtime.RichInt
+import util.ManualClock
+
+class BasicOperationsSuite extends TestSuiteBase {
+
+ override def framework() = "BasicOperationsSuite"
+
+ test("map") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.map(_.toString),
+ input.map(_.map(_.toString))
+ )
+ }
+
+ test("flatmap") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)),
+ input.map(_.flatMap(x => Array(x, x * 2)))
+ )
+ }
+
+ test("filter") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.filter(x => (x % 2 == 0)),
+ input.map(_.filter(x => (x % 2 == 0)))
+ )
+ }
+
+ test("glom") {
+ assert(numInputPartitions === 2, "Number of input partitions has been changed from 2")
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ val output = Seq(
+ Seq( Seq(1, 2), Seq(3, 4) ),
+ Seq( Seq(5, 6), Seq(7, 8) ),
+ Seq( Seq(9, 10), Seq(11, 12) )
+ )
+ val operation = (r: DStream[Int]) => r.glom().map(_.toSeq)
+ testOperation(input, operation, output)
+ }
+
+ test("mapPartitions") {
+ assert(numInputPartitions === 2, "Number of input partitions has been changed from 2")
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23))
+ val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _)))
+ testOperation(input, operation, output, true)
+ }
+
+ test("groupByKey") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(),
+ Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ),
+ true
+ )
+ }
+
+ test("reduceByKey") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+ Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+ true
+ )
+ }
+
+ test("reduce") {
+ testOperation(
+ Seq(1 to 4, 5 to 8, 9 to 12),
+ (s: DStream[Int]) => s.reduce(_ + _),
+ Seq(Seq(10), Seq(26), Seq(42))
+ )
+ }
+
+ test("mapValues") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10),
+ Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ),
+ true
+ )
+ }
+
+ test("flatMapValues") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)),
+ Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ),
+ true
+ )
+ }
+
+ test("cogroup") {
+ val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() )
+ val outputData = Seq(
+ Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ),
+ Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ),
+ Seq( ("", (Seq(1), Seq())) ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("join") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (1, "x")), ("b", (1, "x")) ),
+ Seq( ("", (1, "x")) ),
+ Seq( ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x,1)).join(s2.map(x => (x,"x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("updateStateByKey") {
+ val inputData =
+ Seq(
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData =
+ Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1))
+ )
+
+ val updateStateOperation = (s: DStream[String]) => {
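+      // Add the counts seen in this batch to the previously accumulated count for each key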
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self))
+ }
+
+ testOperation(inputData, updateStateOperation, outputData, true)
+ }
+
+ test("forgetting of RDDs - map and window operations") {
+ assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second")
+
+ val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq
+ val rememberDuration = Seconds(3)
+
+    assert(input.size === 10, "Number of inputs has changed")
+
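+    // Chain two window operations so that each DStream in the lineage must
+    // remember its RDDs for a different duration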
+ def operation(s: DStream[Int]): DStream[(Int, Int)] = {
+ s.map(x => (x % 10, 1))
+ .window(Seconds(2), Seconds(1))
+ .window(Seconds(4), Seconds(2))
+ }
+
+ val ssc = setupStreams(input, operation _)
+ ssc.setRememberDuration(rememberDuration)
+ runStreams[(Int, Int)](ssc, input.size, input.size / 2)
+
+ val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head
+ val windowedStream1 = windowedStream2.dependencies.head
+ val mappedStream = windowedStream1.dependencies.head
+
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ assert(clock.time === Seconds(10).milliseconds)
+
+ // IDEALLY
+ // WindowedStream2 should remember till 7 seconds: 10, 8,
+ // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5
+ // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3,
+
+ // IN THIS TEST
+ // WindowedStream2 should remember till 7 seconds: 10, 8,
+ // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4
+ // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2
+
+ // WindowedStream2
+ assert(windowedStream2.generatedRDDs.contains(Seconds(10)))
+ assert(windowedStream2.generatedRDDs.contains(Seconds(8)))
+ assert(!windowedStream2.generatedRDDs.contains(Seconds(6)))
+
+ // WindowedStream1
+ assert(windowedStream1.generatedRDDs.contains(Seconds(10)))
+ assert(windowedStream1.generatedRDDs.contains(Seconds(4)))
+ assert(!windowedStream1.generatedRDDs.contains(Seconds(3)))
+
+ // MappedStream
+ assert(mappedStream.generatedRDDs.contains(Seconds(10)))
+ assert(mappedStream.generatedRDDs.contains(Seconds(2)))
+ assert(!mappedStream.generatedRDDs.contains(Seconds(1)))
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
new file mode 100644
index 0000000000..6dcedcf463
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -0,0 +1,53 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+import java.io.File
+
+class CheckpointSuite extends TestSuiteBase {
+
+ override def framework() = "CheckpointSuite"
+
+ override def checkpointFile() = "checkpoint"
+
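+  // Runs the first half of the batches (writing the checkpoint file along the way),
+  // then restarts a new StreamingContext from that checkpoint, runs the remaining
+  // batches, and verifies the collected output against the full expected output.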
+ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean = false
+ ) {
+
+ // Current code assumes that:
+ // number of inputs = number of outputs = number of batches to be run
+
+ val totalNumBatches = input.size
+ val initialNumBatches = input.size / 2
+ val nextNumBatches = totalNumBatches - initialNumBatches
+ val initialNumExpectedOutputs = initialNumBatches
+
+ // Do half the computation (half the number of batches), create checkpoint file and quit
+ val ssc = setupStreams[U, V](input, operation)
+ val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
+ verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet)
+ Thread.sleep(1000)
+
+ // Restart and complete the computation from checkpoint file
+ val sscNew = new StreamingContext(checkpointFile)
+ sscNew.setCheckpointDetails(null, null)
+ val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size)
+ verifyOutput[V](outputNew, expectedOutput, useSet)
+
+ new File(checkpointFile).delete()
+ new File(checkpointFile + ".bk").delete()
+ new File("." + checkpointFile + ".crc").delete()
+ new File("." + checkpointFile + ".bk.crc").delete()
+ }
+
+ test("simple per-batch operation") {
+ testCheckpointedOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+ Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+ true
+ )
+ }
+}
\ No newline at end of file
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
new file mode 100644
index 0000000000..f81ab2607f
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -0,0 +1,114 @@
+package spark.streaming
+
+import java.net.{SocketException, Socket, ServerSocket}
+import java.io.{BufferedWriter, OutputStreamWriter}
+import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import util.ManualClock
+import spark.storage.StorageLevel
+import spark.Logging
+
+
+class InputStreamsSuite extends TestSuiteBase {
+
+ test("network input stream") {
+ val serverPort = 9999
+    val server = new TestServer(serverPort)
+ server.start()
+ val ssc = new StreamingContext(master, framework)
+ ssc.setBatchDuration(batchDuration)
+
+ val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK)
+    val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ val outputStream = new TestOutputStream(networkStream, outputBuffer)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
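+    // Feed one element per batch, advancing the manual clock after each send so
+    // that every line arrives in its own batch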
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3)
+ val expectedOutput = input.map(_.toString)
+ for (i <- 0 until input.size) {
+ server.send(input(i).toString + "\n")
+ Thread.sleep(1000)
+ clock.addToTime(1000)
+ }
+ val startTime = System.currentTimeMillis()
+ while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size)
+ Thread.sleep(100)
+ }
+ Thread.sleep(5000)
+ val timeTaken = System.currentTimeMillis() - startTime
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+ logInfo("Stopping server")
+ server.stop()
+ logInfo("Stopping context")
+ ssc.stop()
+
+ assert(outputBuffer.size === expectedOutput.size)
+ for (i <- 0 until outputBuffer.size) {
+ assert(outputBuffer(i).size === 1)
+ assert(outputBuffer(i).head === expectedOutput(i))
+ }
+ }
+}
+
+
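+// A simple TCP server that writes every message passed to send() to each
+// connected client; used by the network input stream test above.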
+class TestServer(port: Int) extends Logging {
+
+ val queue = new ArrayBlockingQueue[String](100)
+
+ val serverSocket = new ServerSocket(port)
+
+ val servingThread = new Thread() {
+ override def run() {
+ try {
+ while(true) {
+ logInfo("Accepting connections on port " + port)
+ val clientSocket = serverSocket.accept()
+ logInfo("New connection")
+ try {
+ clientSocket.setTcpNoDelay(true)
+ val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream))
+
+ while(clientSocket.isConnected) {
+ val msg = queue.poll(100, TimeUnit.MILLISECONDS)
+ if (msg != null) {
+ outputStream.write(msg)
+ outputStream.flush()
+ logInfo("Message '" + msg + "' sent")
+ }
+ }
+ } catch {
+ case e: SocketException => println(e)
+ } finally {
+ logInfo("Connection closed")
+ if (!clientSocket.isClosed) clientSocket.close()
+ }
+ }
+ } catch {
+ case ie: InterruptedException =>
+
+ } finally {
+ serverSocket.close()
+ }
+ }
+ }
+
+ def start() { servingThread.start() }
+
+ def send(msg: String) { queue.add(msg) }
+
+ def stop() { servingThread.interrupt() }
+}
+
+object TestServer {
+ def main(args: Array[String]) {
+ val s = new TestServer(9999)
+ s.start()
+ while(true) {
+ Thread.sleep(1000)
+ s.send("hello")
+ }
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
new file mode 100644
index 0000000000..c1b7772e7b
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -0,0 +1,216 @@
+package spark.streaming
+
+import spark.{RDD, Logging}
+import util.ManualClock
+import collection.mutable.ArrayBuffer
+import org.scalatest.FunSuite
+import collection.mutable.SynchronizedBuffer
+
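+// An input stream that serves the given input data one Seq per batch interval,
+// producing empty RDDs once the input is exhausted.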
+class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
+ extends InputDStream[T](ssc_) {
+ var currentIndex = 0
+
+ def start() {}
+
+ def stop() {}
+
+ def compute(validTime: Time): Option[RDD[T]] = {
+ logInfo("Computing RDD for time " + validTime)
+ val rdd = if (currentIndex < input.size) {
+ ssc.sc.makeRDD(input(currentIndex), numPartitions)
+ } else {
+ ssc.sc.makeRDD(Seq[T](), numPartitions)
+ }
+ logInfo("Created RDD " + rdd.id)
+ currentIndex += 1
+ Some(rdd)
+ }
+}
+
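+// An output stream that collects the elements of every generated RDD into the
+// given buffer, so that tests can compare them with the expected output.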
+class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+ extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ val collected = rdd.collect()
+ output += collected
+ })
+
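+// Base trait for the streaming test suites: provides helpers to set up a
+// StreamingContext over in-memory input, run it for a given number of batches
+// under a ManualClock, and verify the collected output.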
+trait TestSuiteBase extends FunSuite with Logging {
+
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+ def framework() = "TestSuiteBase"
+
+ def master() = "local[2]"
+
+ def batchDuration() = Seconds(1)
+
+ def checkpointFile() = null.asInstanceOf[String]
+
+ def checkpointInterval() = batchDuration
+
+ def numInputPartitions() = 2
+
+ def maxWaitTimeMillis() = 10000
+
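+  // Sets up a StreamingContext that applies `operation` to a TestInputStream
+  // over `input` and collects the results with a TestOutputStream.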
+ def setupStreams[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V]
+ ): StreamingContext = {
+
+ // Create StreamingContext
+ val ssc = new StreamingContext(master, framework)
+ ssc.setBatchDuration(batchDuration)
+ if (checkpointFile != null) {
+ ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ }
+
+ // Setup the stream computation
+ val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+ val operatedStream = operation(inputStream)
+ val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+ ssc.registerInputStream(inputStream)
+ ssc.registerOutputStream(outputStream)
+ ssc
+ }
+
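+  // Same as above, but for operations that combine two input streams.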
+ def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W]
+ ): StreamingContext = {
+
+ // Create StreamingContext
+ val ssc = new StreamingContext(master, framework)
+ ssc.setBatchDuration(batchDuration)
+ if (checkpointFile != null) {
+ ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ }
+
+ // Setup the stream computation
+ val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
+ val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
+ val operatedStream = operation(inputStream1, inputStream2)
+ val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+ ssc.registerInputStream(inputStream1)
+ ssc.registerInputStream(inputStream2)
+ ssc.registerOutputStream(outputStream)
+ ssc
+ }
+
+
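+  // Runs the streams set up in `ssc` for `numBatches` batches by advancing the
+  // ManualClock, waits up to maxWaitTimeMillis for `numExpectedOutput` output
+  // batches to be collected, and returns them.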
+ def runStreams[V: ClassManifest](
+ ssc: StreamingContext,
+ numBatches: Int,
+ numExpectedOutput: Int
+ ): Seq[Seq[V]] = {
+
+    assert(numBatches > 0, "Number of batches to run the stream computation is zero")
+    assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " batches is zero")
+ logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
+
+ // Get the output buffer
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ val output = outputStream.output
+
+ try {
+ // Start computation
+ ssc.start()
+
+ // Advance manual clock
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ logInfo("Manual clock before advancing = " + clock.time)
+ clock.addToTime(numBatches * batchDuration.milliseconds)
+ logInfo("Manual clock after advancing = " + clock.time)
+
+ // Wait until expected number of output items have been generated
+ val startTime = System.currentTimeMillis()
+ while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
+ Thread.sleep(100)
+ }
+ val timeTaken = System.currentTimeMillis() - startTime
+
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+ assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
+
+      Thread.sleep(500) // Give some time for the forgetting of old RDDs to complete
+ } catch {
+ case e: Exception => e.printStackTrace(); throw e;
+ } finally {
+ ssc.stop()
+ }
+
+ output
+ }
+
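+  // Verifies the collected output against the expected output, comparing each
+  // batch as a set (order-insensitive) when `useSet` is true, else as a list.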
+ def verifyOutput[V: ClassManifest](
+ output: Seq[Seq[V]],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean
+ ) {
+ logInfo("--------------------------------")
+ logInfo("output.size = " + output.size)
+ logInfo("output")
+ output.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Match the output with the expected output
+    assert(output.size === expectedOutput.size, "Number of outputs does not match")
+ for (i <- 0 until output.size) {
+ if (useSet) {
+ assert(output(i).toSet === expectedOutput(i).toSet)
+ } else {
+ assert(output(i).toList === expectedOutput(i).toList)
+ }
+ }
+ logInfo("Output verified successfully")
+ }
+
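+  // Tests a unary DStream operation, running one batch per expected output batch.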
+ def testOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean = false
+ ) {
+ testOperation[U, V](input, operation, expectedOutput, -1, useSet)
+ }
+
+ def testOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ numBatches: Int,
+ useSet: Boolean
+ ) {
+ val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
+ val ssc = setupStreams[U, V](input, operation)
+ val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[V](output, expectedOutput, useSet)
+ }
+
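+  // Tests a binary DStream operation over two input streams.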
+ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W],
+ expectedOutput: Seq[Seq[W]],
+ useSet: Boolean
+ ) {
+ testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet)
+ }
+
+ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W],
+ expectedOutput: Seq[Seq[W]],
+ numBatches: Int,
+ useSet: Boolean
+ ) {
+ val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
+ val ssc = setupStreams[U, V, W](input1, input2, operation)
+ val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[W](output, expectedOutput, useSet)
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
new file mode 100644
index 0000000000..90d67844bb
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -0,0 +1,188 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+
+class WindowOperationsSuite extends TestSuiteBase {
+
+ override def framework() = "WindowOperationsSuite"
+
+ override def maxWaitTimeMillis() = 20000
+
+ val largerSlideInput = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2)), // 1st window from here
+ Seq(("a", 3)),
+ Seq(("a", 4)), // 2nd window from here
+ Seq(("a", 5)),
+ Seq(("a", 6)), // 3rd window from here
+ Seq(),
+ Seq() // 4th window from here
+ )
+
+ val largerSlideOutput = Seq(
+ Seq(("a", 3)),
+ Seq(("a", 10)),
+ Seq(("a", 18)),
+ Seq(("a", 11))
+ )
+
+
+ val bigInput = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1), ("b", 1), ("c", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1)),
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1), ("b", 1), ("c", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1)),
+ Seq()
+ )
+
+ val bigOutput = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 1)),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 1))
+ )
+
+ /*
+  The output of reduceByKeyAndWindow with an inverse reduce function differs
+  from the naive reduceByKeyAndWindow. Even if the count of a particular key
+  drops to 0, the key is not eliminated from the RDDs of
+  ReducedWindowedDStream, so the number of keys in these RDDs can grow
+  forever. A more generalized version that allows elimination of keys
+  should be considered.
+ */
+ val bigOutputInv = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1), ("c", 0)),
+ Seq(("a", 1), ("b", 0), ("c", 0)),
+ Seq(("a", 1), ("b", 0), ("c", 0)),
+ Seq(("a", 2), ("b", 1), ("c", 0)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1), ("c", 0)),
+ Seq(("a", 1), ("b", 0), ("c", 0))
+ )
+
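+  // Registers a test for the naive reduceByKeyAndWindow (no inverse function),
+  // running enough batches to cover one slide interval per expected output batch.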
+ def testReduceByKeyAndWindow(
+ name: String,
+ input: Seq[Seq[(String, Int)]],
+ expectedOutput: Seq[Seq[(String, Int)]],
+ windowTime: Time = batchDuration * 2,
+ slideTime: Time = batchDuration
+ ) {
+ test("reduceByKeyAndWindow - " + name) {
+ val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist()
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+ }
+
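+  // Registers a test for reduceByKeyAndWindow with an inverse reduce function,
+  // which updates each window incrementally by subtracting old values instead
+  // of recomputing the whole window.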
+ def testReduceByKeyAndWindowInv(
+ name: String,
+ input: Seq[Seq[(String, Int)]],
+ expectedOutput: Seq[Seq[(String, Int)]],
+ windowTime: Time = batchDuration * 2,
+ slideTime: Time = batchDuration
+ ) {
+ test("reduceByKeyAndWindowInv - " + name) {
+ val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist()
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+ }
+
+
+ // Testing naive reduceByKeyAndWindow (without invertible function)
+
+ testReduceByKeyAndWindow(
+ "basic reduction",
+ Seq( Seq(("a", 1), ("a", 3)) ),
+ Seq( Seq(("a", 4)) )
+ )
+
+ testReduceByKeyAndWindow(
+ "key already in window and new value added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)) )
+ )
+
+ testReduceByKeyAndWindow(
+ "new key added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) )
+ )
+
+ testReduceByKeyAndWindow(
+ "key removed from window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() )
+ )
+
+ testReduceByKeyAndWindow(
+ "larger slide time",
+ largerSlideInput,
+ largerSlideOutput,
+ Seconds(4),
+ Seconds(2)
+ )
+
+ testReduceByKeyAndWindow("big test", bigInput, bigOutput)
+
+
+ // Testing reduceByKeyAndWindow (with invertible reduce function)
+
+ testReduceByKeyAndWindowInv(
+ "basic reduction",
+ Seq(Seq(("a", 1), ("a", 3)) ),
+ Seq(Seq(("a", 4)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "key already in window and new value added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "new key added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "key removed from window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "larger slide time",
+ largerSlideInput,
+ largerSlideOutput,
+ Seconds(4),
+ Seconds(2)
+ )
+
+ testReduceByKeyAndWindowInv("big test", bigInput, bigOutputInv)
+}