59 files changed, 6001 insertions, 5 deletions
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d26cccbfe1..0d37075ef3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -58,10 +58,10 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend * @param environment Environment variables to set on worker nodes. */ class SparkContext( - master: String, - jobName: String, + val master: String, + val jobName: String, val sparkHome: String, - jars: Seq[String], + val jars: Seq[String], environment: Map[String, String]) extends Logging { @@ -595,6 +595,11 @@ class SparkContext( * various Spark features. */ object SparkContext { + + // TODO: temporary hack for using HDFS as input in streaming + var inputFile: String = null + var idealPartitions: Int = 1 + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 567c4b1475..1bdde25896 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -247,7 +247,7 @@ private object Utils extends Logging { * millisecond. */ def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms " + return " " + (System.currentTimeMillis - startTimeMs) + " ms" } /**
diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..d11ed163ce --- /dev/null +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,56 @@ +package spark.util + +import java.io.OutputStream + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + var lastSyncTime = System.nanoTime() + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + override def write(bytes: Array[Byte], offset: Int, length: Int) { + val CHUNK_SIZE = 8192 + var pos = 0 + while (pos < length) { + val writeSize = math.min(length - pos, CHUNK_SIZE) + waitToWrite(writeSize) + out.write(bytes, offset + pos, writeSize) + pos += writeSize + } + } + + def waitToWrite(numBytes: Int) { + while (true) { + val now = System.nanoTime() + val elapsed = math.max(now - lastSyncTime, 1) + val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9) + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + (1e10).toLong) { + // Ten seconds have passed since lastSyncTime; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + return + } else { + Thread.sleep(5) + } + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } +}
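As a usage illustration (not part of this commit; the file name and rate are assumed), the rate limiter above can wrap any OutputStream: waitToWrite() sleeps in 5 ms steps whenever the observed throughput since the last sync exceeds bytesPerSec.

import java.io.FileOutputStream
import spark.util.RateLimitedOutputStream

object RateLimitExample {
  def main(args: Array[String]) {
    // Cap average write throughput at ~64 KB/s (hypothetical path and rate)
    val out = new RateLimitedOutputStream(new FileOutputStream("/tmp/limited.out"), 64 * 1024)
    val chunk = new Array[Byte](8192)
    for (i <- 1 to 100) {
      out.write(chunk)
    }
    out.close() // also closes the underlying stream
  }
}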
\ No newline at end of file
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2f67bb9921..688bb16a03 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -17,7 +17,7 @@ object SparkBuild extends Build { //val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1" //val HADOOP_MAJOR_VERSION = "2" - lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel) + lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming) lazy val core = Project("core", file("core"), settings = coreSettings) @@ -27,6 +27,8 @@ object SparkBuild extends Build { lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn (core) + lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core) + // A configuration to set an alternative publishLocalConfiguration lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -153,6 +155,10 @@ object SparkBuild extends Build { def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") + def streamingSettings = sharedSettings ++ Seq( + name := "spark-streaming" + ) ++ assemblySettings ++ extraAssemblySettings + def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
diff --git a/run b/run --- a/run +++ b/run @@ -63,6 +63,7 @@ CORE_DIR="$FWDIR/core" REPL_DIR="$FWDIR/repl" EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" +STREAMING_DIR="$FWDIR/streaming" # Build up classpath CLASSPATH="$SPARK_CLASSPATH" @@ -74,6 +75,7 @@ fi CLASSPATH+=":$CORE_DIR/src/main/resources" CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $FWDIR/lib_managed/jars -name '*jar'`; do CLASSPATH+=":$jar" done
diff --git a/sentences.txt b/sentences.txt new file mode 100644 index 0000000000..fedf96c66e --- /dev/null +++ b/sentences.txt @@ -0,0 +1,3 @@ +Hello world! +What's up?
+There is no cow level
diff --git a/startTrigger.sh b/startTrigger.sh new file mode 100755 index 0000000000..373dbda93e --- /dev/null +++ b/startTrigger.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +./run spark.streaming.SentenceGenerator localhost 7078 sentences.txt 1
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala new file mode 100644 index 0000000000..83a43d15cb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -0,0 +1,96 @@ +package spark.streaming + +import spark.Utils + +import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.conf.Configuration + +import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} + + +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable { + val master = ssc.sc.master + val framework = ssc.sc.jobName + val sparkHome = ssc.sc.sparkHome + val jars = ssc.sc.jars + val graph = ssc.graph + val checkpointFile = ssc.checkpointFile + val checkpointInterval = ssc.checkpointInterval + + validate() + + def validate() { + assert(master != null, "Checkpoint.master is null") + assert(framework != null, "Checkpoint.framework is null") + assert(graph != null, "Checkpoint.graph is null") + assert(checkpointTime != null, "Checkpoint.checkpointTime is null") + } + + def saveToFile(file: String = checkpointFile) { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + //logInfo("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(this) + oos.close() + fs.close() + } + + def toBytes(): Array[Byte] = { + val bytes = Utils.serialize(this) + bytes + } +} + +object Checkpoint { + + def loadFromFile(file: String): Checkpoint = { + try { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (!fs.exists(path)) { + throw new Exception("Checkpoint file '" + file + "' does not exist") + } + val fis = fs.open(path) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which may be the wrong class loader. Hence, an inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes.
This is a well-known Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + cp + } catch { + case e: Exception => + e.printStackTrace() + throw new Exception("Could not load checkpoint file '" + file + "'", e) + } + } + + def fromBytes(bytes: Array[Byte]): Checkpoint = { + val cp = Utils.deserialize[Checkpoint](bytes) + cp.validate() + cp + } +} + +class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) { + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + try { + return loader.loadClass(desc.getName()) + } catch { + case e: Exception => + } + return super.resolveClass(desc) + } +}
diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala new file mode 100644 index 0000000000..61d088eddb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala @@ -0,0 +1,38 @@ +package spark.streaming + +import spark.{RDD, Partitioner} +import spark.rdd.CoGroupedRDD + +class CoGroupedDStream[K : ClassManifest]( + parents: Seq[DStream[(_, _)]], + partitioner: Partitioner + ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { + val part = partitioner + val rdds = parents.flatMap(_.getOrCompute(validTime)) + if (rdds.size > 0) { + val q = new CoGroupedRDD[K](rdds, part) + Some(q) + } else { + None + } + } + +}
diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala new file mode 100644 index 0000000000..80150708fd --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala @@ -0,0 +1,18 @@ +package spark.streaming + +import spark.RDD + +/** + * An input stream that always returns the same RDD on each timestep. Useful for testing. + */ +class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) + extends InputDStream[T](ssc_) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + Some(rdd) + } +}
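For context, a small hedged sketch (helper names assumed, not part of this commit) of why ObjectInputStreamWithLoader exists: a plain ObjectInputStream resolves classes through the class loader of the calling frame, which can miss classes shipped in application jars, so class resolution is redirected to an explicitly supplied loader such as the thread's context class loader.

package spark.streaming

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutputStream}

object LoaderRoundTrip {
  // Serialize and deserialize obj (which must be Serializable),
  // resolving classes via the given loader instead of the default one.
  def roundTrip(obj: AnyRef, loader: ClassLoader): AnyRef = {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(obj)
    oos.close()
    val ois = new ObjectInputStreamWithLoader(new ByteArrayInputStream(buf.toByteArray), loader)
    val result = ois.readObject()
    ois.close()
    result
  }
}

// e.g. roundTrip(checkpoint, Thread.currentThread().getContextClassLoader)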
\ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala new file mode 100644 index 0000000000..12d7ba97ea --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -0,0 +1,660 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ + +import spark._ +import spark.SparkContext._ +import spark.rdd._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.util.concurrent.ArrayBlockingQueue +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import scala.Some + +abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) +extends Serializable with Logging { + + initLogging() + + /** + * ---------------------------------------------- + * Methods that must be implemented by subclasses + * ---------------------------------------------- + */ + + // Time by which the window slides in this DStream + def slideTime: Time + + // List of parent DStreams on which this DStream depends + def dependencies: List[DStream[_]] + + // Key method that computes RDD for a valid time + def compute (validTime: Time): Option[RDD[T]] + + /** + * --------------------------------------- + * Other general fields and methods of DStream + * --------------------------------------- + */ + + // RDDs generated, marked as protected[streaming] so that test suites can access them + protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] () + + // Time zero for the DStream + protected var zeroTime: Time = null + + // Duration for which the DStream will remember each RDD created + protected var rememberDuration: Time = null + + // Storage level of the RDDs in the stream + protected var storageLevel: StorageLevel = StorageLevel.NONE + + // Checkpoint level and checkpoint interval + protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint + protected var checkpointInterval: Time = null + + // Reference to whole DStream graph + protected var graph: DStreamGraph = null + + def isInitialized = (zeroTime != null) + + // Duration for which the DStream requires its parent DStream to remember each RDD created + def parentRememberDuration = rememberDuration + + // Change this DStream's storage level + def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): DStream[T] = { + if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { + // TODO: not sure this is necessary for DStreams + throw new UnsupportedOperationException( + "Cannot change storage level of a DStream after it was already assigned a level") + } + this.storageLevel = storageLevel + this.checkpointLevel = checkpointLevel + this.checkpointInterval = checkpointInterval + this + } + + // Set caching level for the RDDs created by this DStream + def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null) + + def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY) + + // Turn on the default caching level for this DStream + def cache(): DStream[T] = persist() + + /** + * This method initializes the DStream by setting the "zero" time, based on which + * the validity of future times is calculated. This method also recursively initializes + * its parent DStreams.
+ */ + protected[streaming] def initialize(time: Time) { + if (zeroTime != null && zeroTime != time) { + throw new Exception("ZeroTime is already initialized to " + zeroTime + + ", cannot initialize it again to " + time) + } + zeroTime = time + dependencies.foreach(_.initialize(zeroTime)) + logInfo("Initialized " + this) + } + + protected[streaming] def setContext(s: StreamingContext) { + if (ssc != null && ssc != s) { + throw new Exception("Context is already set in " + this + ", cannot set it again") + } + ssc = s + logInfo("Set context for " + this) + dependencies.foreach(_.setContext(ssc)) + } + + protected[streaming] def setGraph(g: DStreamGraph) { + if (graph != null && graph != g) { + throw new Exception("Graph is already set in " + this + ", cannot set it again") + } + graph = g + dependencies.foreach(_.setGraph(graph)) + } + + protected[streaming] def setRememberDuration(duration: Time = slideTime) { + if (duration == null) { + throw new Exception("Duration for remembering RDDs cannot be set to null for " + this) + } else if (rememberDuration != null && duration < rememberDuration) { + logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration + + " to " + duration + " for " + this) + } else { + rememberDuration = duration + dependencies.foreach(_.setRememberDuration(parentRememberDuration)) + logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) + } + } + + /** This method checks whether the 'time' is valid with respect to slideTime for generating RDD */ + protected def isTimeValid(time: Time): Boolean = { + if (!isInitialized) { + throw new Exception (this + " has not been initialized") + } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { + false + } else { + true + } + } + + /** + * This method either retrieves a precomputed RDD of this DStream, + * or computes the RDD (if the time is valid) + */ + def getOrCompute(time: Time): Option[RDD[T]] = { + // If this DStream was not initialized (i.e., zeroTime not set), then do it + // If RDD was already generated, then retrieve it from HashMap + generatedRDDs.get(time) match { + + // If an RDD was already generated and is being reused, then + // probably all RDDs in this DStream will be reused and hence should be cached + case Some(oldRDD) => Some(oldRDD) + + // if RDD was not generated, and if the time is valid + // (based on sliding time of this DStream), then generate the RDD + case None => { + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + generatedRDDs.put(time, newRDD) + Some(newRDD) + case None => + None + } + } else { + None + } + } + } + } + + /** + * This method generates a SparkStreaming job for the given time + * and may need to be overridden by subclasses + */ + def generateJob(time: Time): Option[Job] = { + getOrCompute(time) match { + case Some(rdd) => { + val jobFunc = () => { + val emptyFunc = { (iterator: Iterator[T]) => {} } + ssc.sc.runJob(rdd, emptyFunc) + } + Some(new Job(time, jobFunc)) + } + case None => None + } + } + + def forgetOldRDDs(time: Time) { + val keys = generatedRDDs.keys + var numForgotten = 0 + keys.foreach(t => { + if (t <= (time -
rememberDuration)) { + generatedRDDs.remove(t) + numForgotten += 1 + //logInfo("Forgot RDD of time " + t + " from " + this) + } + }) + logInfo("Forgot " + numForgotten + " RDDs from " + this) + dependencies.foreach(_.forgetOldRDDs(time)) + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + logDebug(this.getClass().getSimpleName + ".writeObject used") + if (graph != null) { + graph.synchronized { + if (graph.checkpointInProgress) { + oos.defaultWriteObject() + } else { + val msg = "Object of " + this.getClass.getName + " is being serialized " + + " possibly as a part of closure of an RDD operation. This is because " + + " the DStream object is being referred to from within the closure. " + + " Please rewrite the RDD operation inside this DStream to avoid this. " + + " This has been enforced to avoid bloating of Spark tasks " + + " with unnecessary objects." + throw new java.io.NotSerializableException(msg) + } + } + } else { + throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.") + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + logDebug(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + } + + /** + * -------------- + * DStream operations + * -------------- + */ + def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { + new MappedDStream(this, ssc.sc.clean(mapFunc)) + } + + def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { + new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) + } + + def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + + def glom(): DStream[Array[T]] = new GlommedDStream(this) + + def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = { + new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) + } + + def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + + def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) + + def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2) + + def foreach(foreachFunc: T => Unit) { + val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newStream) + newStream + } + + def foreachRDD(foreachFunc: RDD[T] => Unit) { + foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + } + + def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { + val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newStream) + newStream + } + + def transformRDD[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { + transformRDD((r: RDD[T], t: Time) => transformFunc(r)) + } + + def transformRDD[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { + new TransformedDStream(this, ssc.sc.clean(transformFunc)) + } + + def toBlockingQueue() = { + val queue = new ArrayBlockingQueue[RDD[T]](10000) + this.foreachRDD(rdd => { + queue.add(rdd) + }) + queue + } + + def print() { + def foreachFunc = (rdd: RDD[T], time: Time) => { + val first11 = rdd.take(11) + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + first11.take(10).foreach(println) + if (first11.size > 10) println("...") + println() + } + val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + 
ssc.registerOutputStream(newStream) + } + + def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime) + + def window(windowTime: Time, slideTime: Time): DStream[T] = { + new WindowedDStream(this, windowTime, slideTime) + } + + def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime) + + def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = { + this.window(windowTime, slideTime).reduce(reduceFunc) + } + + def reduceByWindow( + reduceFunc: (T, T) => T, + invReduceFunc: (T, T) => T, + windowTime: Time, + slideTime: Time + ): DStream[T] = { + this.map(x => (1, x)) + .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) + .map(_._2) + } + + def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = { + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) + } + + def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) + + def slice(interval: Interval): Seq[RDD[T]] = { + slice(interval.beginTime, interval.endTime) + } + + // Get all the RDDs between fromTime to toTime (both included) + def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + var time = toTime.floor(slideTime) + while (time >= zeroTime && time >= fromTime) { + getOrCompute(time) match { + case Some(rdd) => rdds += rdd + case None => //throw new Exception("Could not get RDD for time " + time) + } + time -= slideTime + } + rdds.toSeq + } + + def register() { + ssc.registerOutputStream(this) + } +} + + +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { + + override def dependencies = List() + + override def slideTime = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") + ssc.graph.batchDuration + } + + def start() + + def stop() +} + + +/** + * TODO + */ + +class MappedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + mapFunc: T => U + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + + +/** + * TODO + */ + +class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + flatMapFunc: T => Traversable[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + + +/** + * TODO + */ + +class FilteredDStream[T: ClassManifest]( + parent: DStream[T], + filterFunc: T => Boolean + ) extends DStream[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + + +/** + * TODO + */ + +class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + mapPartFunc: Iterator[T] => Iterator[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): 
Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) + } +} + + +/** + * TODO + */ + +class GlommedDStream[T: ClassManifest](parent: DStream[T]) + extends DStream[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} + + +/** + * TODO + */ + +class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: DStream[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner + ) extends DStream [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) + case None => None + } + } +} + + +/** + * TODO + */ + +class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + mapValueFunc: V => U + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) + } +} + + +/** + * TODO + */ + +class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + flatMapValueFunc: V => TraversableOnce[U] + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) + } +} + + + +/** + * TODO + */ + +class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) + extends DStream[T](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime: Time = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} + + +/** + * TODO + */ + +class PerElementForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: T => Unit + ) extends DStream[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + val sparkJobFunc = { + (iterator: Iterator[T]) => iterator.foreach(foreachFunc) + } + ssc.sc.runJob(rdd, 
sparkJobFunc) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} + + +/** + * TODO + */ + +class PerRDDForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: (RDD[T], Time) => Unit + ) extends DStream[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} + + +/** + * TODO + */ + +class TransformedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + transformFunc: (RDD[T], Time) => RDD[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) + } +}
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala new file mode 100644 index 0000000000..ac44d7a2a6 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -0,0 +1,124 @@ +package spark.streaming + +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import collection.mutable.ArrayBuffer +import spark.Logging + +final class DStreamGraph extends Serializable with Logging { + initLogging() + + private val inputStreams = new ArrayBuffer[InputDStream[_]]() + private val outputStreams = new ArrayBuffer[DStream[_]]() + + private[streaming] var zeroTime: Time = null + private[streaming] var batchDuration: Time = null + private[streaming] var rememberDuration: Time = null + private[streaming] var checkpointInProgress = false + + def start(time: Time) { + this.synchronized { + if (zeroTime != null) { + throw new Exception("DStream graph computation already started") + } + zeroTime = time + outputStreams.foreach(_.initialize(zeroTime)) + outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values + if (rememberDuration != null) { + // if custom rememberDuration has been provided, set the rememberDuration + outputStreams.foreach(_.setRememberDuration(rememberDuration)) + } + inputStreams.par.foreach(_.start()) + } + } + + def stop() { + this.synchronized { + inputStreams.par.foreach(_.stop()) + } + } + + private[streaming] def setContext(ssc: StreamingContext) { + this.synchronized { + outputStreams.foreach(_.setContext(ssc)) + } + } + + def setBatchDuration(duration: Time) { + this.synchronized { + if (batchDuration != null) { + throw new Exception("Batch duration already set as " + batchDuration + + ". cannot set it again.") + } + } + batchDuration = duration + } + + def setRememberDuration(duration: Time) { + this.synchronized { + if (rememberDuration != null) { + throw new Exception("Remember duration already set as " + rememberDuration + + ".
cannot set it again.") + } + } + rememberDuration = duration + } + + def addInputStream(inputStream: InputDStream[_]) { + this.synchronized { + inputStream.setGraph(this) + inputStreams += inputStream + } + } + + def addOutputStream(outputStream: DStream[_]) { + this.synchronized { + outputStream.setGraph(this) + outputStreams += outputStream + } + } + + def getInputStreams() = inputStreams.toArray + + def getOutputStreams() = outputStreams.toArray + + def generateRDDs(time: Time): Seq[Job] = { + this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + } + } + + def forgetOldRDDs(time: Time) { + this.synchronized { + outputStreams.foreach(_.forgetOldRDDs(time)) + } + } + + def validate() { + this.synchronized { + assert(batchDuration != null, "Batch duration has not been set") + assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") + } + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + this.synchronized { + logDebug("DStreamGraph.writeObject used") + checkpointInProgress = true + oos.defaultWriteObject() + checkpointInProgress = false + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + this.synchronized { + logDebug("DStreamGraph.readObject used") + checkpointInProgress = true + ois.defaultReadObject() + checkpointInProgress = false + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala new file mode 100644 index 0000000000..537ec88047 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -0,0 +1,87 @@ +package spark.streaming + +import spark.RDD +import spark.rdd.UnionRDD + +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import java.io.{ObjectInputStream, IOException} + + +class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + @transient ssc_ : StreamingContext, + directory: String, + filter: PathFilter = FileInputDStream.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + var lastModTime: Long = 0 + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val newFilter = new PathFilter() { + var latestModTime = 0L + + def accept(path: Path): Boolean = { + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime <= lastModTime) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + } + return true + } + } + } + + val newFiles = fs.listStatus(path, newFilter) + logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + if (newFiles.length > 0) { + lastModTime = newFilter.latestModTime + } + val newRDD = new UnionRDD(ssc.sc, newFiles.map( + file 
=> ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) + Some(newRDD) + } +} + +object FileInputDStream { + val defaultPathFilter = new PathFilter with Serializable { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala new file mode 100644 index 0000000000..ffb7725ac9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -0,0 +1,50 @@ +package spark.streaming + +case class Interval(beginTime: Time, endTime: Time) { + def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) + + def duration(): Time = endTime - beginTime + + def + (time: Time): Interval = { + new Interval(beginTime + time, endTime + time) + } + + def - (time: Time): Interval = { + new Interval(beginTime - time, endTime - time) + } + + def < (that: Interval): Boolean = { + if (this.duration != that.duration) { + throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]") + } + this.endTime < that.endTime + } + + def <= (that: Interval) = (this < that || this == that) + + def > (that: Interval) = !(this <= that) + + def >= (that: Interval) = !(this < that) + + def next(): Interval = { + this + (endTime - beginTime) + } + + def isZero = (beginTime.isZero && endTime.isZero) + + def toFormattedString = beginTime.toFormattedString + "-" + endTime.toFormattedString + + override def toString = "[" + beginTime + ", " + endTime + "]" +} + +object Interval { + def zero() = new Interval (Time.zero, Time.zero) + + def currentInterval(intervalDuration: Time): Interval = { + val time = Time(System.currentTimeMillis) + val intervalBegin = time.floor(intervalDuration) + Interval(intervalBegin, intervalBegin + intervalDuration) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala new file mode 100644 index 0000000000..0bcb6fd8dc --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -0,0 +1,22 @@ +package spark.streaming + +import java.util.concurrent.atomic.AtomicLong + +class Job(val time: Time, func: () => _) { + val id = Job.getNewId() + def run(): Long = { + val startTime = System.currentTimeMillis + func() + val stopTime = System.currentTimeMillis + (stopTime - startTime) + } + + override def toString = "streaming job " + id + " @ " + time +} + +object Job { + val id = new AtomicLong(0) + + def getNewId() = id.getAndIncrement() +} + diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala new file mode 100644 index 0000000000..9bf9251519 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -0,0 +1,32 @@ +package spark.streaming + +import spark.Logging +import spark.SparkEnv +import java.util.concurrent.Executors + + +class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { + + class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { + def run() { + SparkEnv.set(ssc.env) + try { + val timeTaken = job.run() + logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( + (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) + } catch { + case e: Exception => + logError("Running " + job + " failed", e) + } + } + } + + initLogging() + + val 
jobExecutor = Executors.newFixedThreadPool(numThreads) + + def runJob(job: Job) { + jobExecutor.execute(new JobHandler(ssc, job)) + logInfo("Added " + job + " to queue") + } +}
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala new file mode 100644 index 0000000000..f3f4c3ab13 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -0,0 +1,151 @@ +package spark.streaming + +import scala.collection.mutable.ArrayBuffer + +import spark.{Logging, SparkEnv, RDD} +import spark.rdd.BlockRDD +import spark.storage.StorageLevel + +import java.nio.ByteBuffer + +import akka.actor.{Props, Actor} +import akka.pattern.ask +import akka.dispatch.Await +import akka.util.duration._ + +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) + extends InputDStream[T](ssc_) { + + // This is a unique identifier that is used to match the network receiver with the + // corresponding network input stream. + val id = ssc.getNewNetworkStreamId() + + /** + * This method creates the receiver object that will be sent to the workers + * to receive data. This method needs to be defined by any specific implementation + * of a NetworkInputDStream. + */ + def createReceiver(): NetworkReceiver[T] + + // Nothing to start or stop as both are taken care of by the NetworkInputTracker. + def start() {} + + def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + Some(new BlockRDD[T](ssc.sc, blockIds)) + } +} + + +sealed trait NetworkReceiverMessage +case class StopReceiver(msg: String) extends NetworkReceiverMessage +case class ReportBlock(blockId: String) extends NetworkReceiverMessage +case class ReportError(msg: String) extends NetworkReceiverMessage + +abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging { + + initLogging() + + lazy protected val env = SparkEnv.get + + lazy protected val actor = env.actorSystem.actorOf( + Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + + lazy protected val receivingThread = Thread.currentThread() + + /** This method will be called to start receiving data. */ + protected def onStart() + + /** This method will be called to stop receiving data. */ + protected def onStop() + + /** + * This method starts the receiver. First it accesses all the lazy members to + * materialize them. Then it calls the user-defined onStart() method to start + * other threads, etc. required to receive the data. + */ + def start() { + try { + // Access the lazy vals to materialize them + env + actor + receivingThread + + // Call user-defined onStart() + onStart() + } catch { + case ie: InterruptedException => + logInfo("Receiving thread interrupted") + //println("Receiving thread interrupted") + case e: Exception => + stopOnError(e) + } + } + + /** + * This method stops the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). Then it calls the user-defined + * onStop() method to stop other threads and/or do cleanup. + */ + def stop() { + receivingThread.interrupt() + onStop() + //TODO: terminate the actor + } + + /** + * This method stops the receiver and reports the exception to the tracker. + * This should be called whenever an exception has happened on any thread + * of the receiver.
+ */ + protected def stopOnError(e: Exception) { + logError("Error receiving data", e) + stop() + actor ! ReportError(e.toString) + } + + /** + * This method pushes a block (as iterator of values) into the block manager. + */ + protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) { + val buffer = new ArrayBuffer[T] ++ iterator + env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) + actor ! ReportBlock(blockId) + } + + /** + * This method pushes a block (as bytes) into the block manager. + */ + protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + env.blockManager.putBytes(blockId, bytes, level) + actor ! ReportBlock(blockId) + } + + /** A helper actor that communicates with the NetworkInputTracker */ + private class NetworkReceiverActor extends Actor { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + val tracker = env.actorSystem.actorFor(url) + val timeout = 5.seconds + + override def preStart() { + val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive() = { + case ReportBlock(blockId) => + tracker ! AddBlocks(streamId, Array(blockId)) + case ReportError(msg) => + tracker ! DeregisterReceiver(streamId, msg) + case StopReceiver(msg) => + stop() + tracker ! DeregisterReceiver(streamId, msg) + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala new file mode 100644 index 0000000000..07ef79415d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -0,0 +1,118 @@ +package spark.streaming + +import spark.Logging +import spark.SparkEnv + +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +trait NetworkInputTrackerMessage +case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage +case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage +case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage + + +class NetworkInputTracker( + @transient ssc: StreamingContext, + @transient networkInputStreams: Array[NetworkInputDStream[_]]) + extends Logging { + + val networkInputStreamIds = networkInputStreams.map(_.id).toArray + val receiverExecutor = new ReceiverExecutor() + val receiverInfo = new HashMap[Int, ActorRef] + val receivedBlockIds = new HashMap[Int, Queue[String]] + val timeout = 5000.milliseconds + + var currentTime: Time = null + + def start() { + ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") + receiverExecutor.start() + } + + def stop() { + receiverExecutor.interrupt() + receiverExecutor.stopReceivers() + } + + def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) + } + result.toArray + } + + private class NetworkInputTrackerActor extends Actor { + def receive = { + case RegisterReceiver(streamId, 
receiverActor) => { + if (!networkInputStreamIds.contains(streamId)) { + throw new Exception("Register received for unexpected id " + streamId) + } + receiverInfo += ((streamId, receiverActor)) + logInfo("Registered receiver for network stream " + streamId) + sender ! true + } + case AddBlocks(streamId, blockIds) => { + val tmp = receivedBlockIds.synchronized { + if (!receivedBlockIds.contains(streamId)) { + receivedBlockIds += ((streamId, new Queue[String])) + } + receivedBlockIds(streamId) + } + tmp.synchronized { + tmp ++= blockIds + } + } + case DeregisterReceiver(streamId, msg) => { + receiverInfo -= streamId + logInfo("De-registered receiver for network stream " + streamId + + " with message " + msg) + //TODO: Do something about the corresponding NetworkInputDStream + } + } + } + + class ReceiverExecutor extends Thread { + val env = ssc.env + + override def run() { + try { + SparkEnv.set(env) + startReceivers() + } catch { + case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") + } finally { + stopReceivers() + } + } + + def startReceivers() { + val receivers = networkInputStreams.map(_.createReceiver()) + val tempRDD = ssc.sc.makeRDD(receivers, receivers.size) + + val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { + if (!iterator.hasNext) { + throw new Exception("Could not start receiver as details not found.") + } + iterator.next().start() + } + ssc.sc.runJob(tempRDD, startReceiver) + } + + def stopReceivers() { + //implicit val ec = env.actorSystem.dispatcher + receiverInfo.values.foreach(_ ! StopReceiver) + //val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList + //val futureOfList = Future.sequence(listOfFutures) + //Await.result(futureOfList, timeout) + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala new file mode 100644 index 0000000000..ce1f4ad0a0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -0,0 +1,236 @@ +package spark.streaming + +import scala.collection.mutable.ArrayBuffer +import spark.{Manifests, RDD, Partitioner, HashPartitioner} +import spark.streaming.StreamingContext._ +import javax.annotation.Nullable + +class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)]) +extends Serializable { + + def ssc = self.ssc + + def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { + new HashPartitioner(numPartitions) + } + + /* ---------------------------------- */ + /* DStream operations for key-value pairs */ + /* ---------------------------------- */ + + def groupByKey(): DStream[(K, Seq[V])] = { + groupByKey(defaultPartitioner()) + } + + def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { + groupByKey(defaultPartitioner(numPartitions)) + } + + def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { + val createCombiner = (v: V) => ArrayBuffer[V](v) + val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) + val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) + combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner).asInstanceOf[DStream[(K, Seq[V])]] + } + + def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { + reduceByKey(reduceFunc, defaultPartitioner()) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { + reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) + } + + def reduceByKey(reduceFunc: (V, V) 
=> V, partitioner: Partitioner): DStream[(K, V)] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) + } + + private def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner) : ShuffledDStream[K, V, C] = { + new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) + } + + def groupByKeyAndWindow(windowTime: Time, slideTime: Time): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner()) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int + ): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner(numPartitions)) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + partitioner: Partitioner + ): DStream[(K, Seq[V])] = { + self.window(windowTime, slideTime).groupByKey(partitioner) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time + ): DStream[(K, V)] = { + reduceByKeyAndWindow(reduceFunc, windowTime, self.slideTime, defaultPartitioner()) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time + ): DStream[(K, V)] = { + reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner()) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int + ): DStream[(K, V)] = { + reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + partitioner: Partitioner + ): DStream[(K, V)] = { + self.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) + } + + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be + // "subtracted" using invReduceFunc.
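// A hedged numeric sketch (values assumed for illustration, not part of this
// commit): with reduceFunc = (_ + _) and invReduceFunc = (_ - _), the reduced
// value of the new window is derived from the previous window's value instead
// of being recomputed from scratch:
//
//   val previousWindowSum = 10 // reduced value of the previous window
//   val departedSum = 3        // reduced value of batches that slid out of the window
//   val arrivedSum = 5         // reduced value of batches that slid into the window
//   val currentWindowSum = previousWindowSum - departedSum + arrivedSum // = 12
//
// ReducedWindowedDStream (later in this commit) implements this update per key.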
+ + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time + ): DStream[(K, V)] = { + + reduceByKeyAndWindow( + reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner()) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int + ): DStream[(K, V)] = { + + reduceByKeyAndWindow( + reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + partitioner: Partitioner + ): DStream[(K, V)] = { + + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) + new ReducedWindowedDStream[K, V]( + self, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner) + } + + // TODO: + // + // + // + // + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Seq[V], Option[S]) => Option[S] + ): DStream[(K, S)] = { + updateStateByKey(updateFunc, defaultPartitioner()) + } + + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Seq[V], Option[S]) => Option[S], + numPartitions: Int + ): DStream[(K, S)] = { + updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) + } + + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Seq[V], Option[S]) => Option[S], + partitioner: Partitioner + ): DStream[(K, S)] = { + val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { + iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + } + updateStateByKey(newUpdateFunc, partitioner, true) + } + + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], + partitioner: Partitioner, + rememberPartitioner: Boolean + ): DStream[(K, S)] = { + new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner) + } + + + def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = { + new MapValuesDStream[K, V, U](self, mapValuesFunc) + } + + def flatMapValues[U: ClassManifest]( + flatMapValuesFunc: V => TraversableOnce[U] + ): DStream[(K, U)] = { + new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc) + } + + def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner()) + } + + def cogroup[W: ClassManifest]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (Seq[V], Seq[W]))] = { + + val cgd = new CoGroupedDStream[K]( + Seq(self.asInstanceOf[DStream[(_, _)]], other.asInstanceOf[DStream[(_, _)]]), + partitioner + ) + val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)( + classManifest[K], + Manifests.seqSeqManifest + ) + pdfs.mapValues { + case Seq(vs, ws) => + (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) + } + } + + def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { + join[W](other, defaultPartitioner()) + } + + def join[W: ClassManifest](other: DStream[(K, W)], partitioner: Partitioner): DStream[(K, (V, W))] = { + this.cogroup(other, partitioner) + .flatMapValues{ + case (vs, ws) => + for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala new file mode 100644 index 0000000000..bb86e51932 --- 
/dev/null +++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala @@ -0,0 +1,40 @@ +package spark.streaming + +import spark.RDD +import spark.rdd.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer + +class QueueInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputDStream[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + if (oneAtATime) { + Some(buffer.first) + } else { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala new file mode 100644 index 0000000000..e022b85fbe --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -0,0 +1,83 @@ +package spark.streaming + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue +import spark._ +import spark.storage.StorageLevel + +/** + * An input stream that reads blocks of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. 
+ */ +class RawInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + def createReceiver(): NetworkReceiver[T] = { + new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] + } +} + +class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) + extends NetworkReceiver[Any](streamId) { + + var blockPushingThread: Thread = null + + def onStart() { + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + blockPushingThread = new DaemonThread { + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + pushBlock(blockId, buffer, storageLevel) + } + } + } + blockPushingThread.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + def onStop() { + blockPushingThread.interrupt() + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala new file mode 100644 index 0000000000..1c57d5f855 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -0,0 +1,143 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ + +import spark.RDD +import spark.rdd.UnionRDD +import spark.rdd.CoGroupedRDD +import spark.Partitioner +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import collection.SeqProxy + +class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( + parent: DStream[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + _windowTime: Time, + _slideTime: Time, + partitioner: Partitioner + ) extends DStream[(K,V)](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + + def windowTime: Time = _windowTime + + override def dependencies = List(reducedStream) + + override def slideTime: Time = _slideTime + + //TODO: This is wrong. 
This should depend on the checkpointInterval + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): DStream[(K,V)] = { + super.persist(storageLevel, checkpointLevel, checkpointInterval) + reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + this + } + + protected[streaming] override def setRememberDuration(time: Time) { + if (rememberDuration == null || rememberDuration < time) { + rememberDuration = time + dependencies.foreach(_.setRememberDuration(rememberDuration + windowTime)) + } + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc + + val currentTime = validTime + val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) + val previousWindow = currentWindow - slideTime + + logDebug("Window time = " + windowTime) + logDebug("Slide time = " + slideTime) + logDebug("ZeroTime = " + zeroTime) + logDebug("Current window = " + currentWindow) + logDebug("Previous window = " + previousWindow) + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + + // Get the RDDs of the reduced values in "old time steps" + val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) + logDebug("# old RDDs = " + oldRDDs.size) + + // Get the RDDs of the reduced values in "new time steps" + val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) + logDebug("# new RDDs = " + newRDDs.size) + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + + // Make the list of RDDs that needs to cogrouped together for reducing their reduced values + val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs + + // Cogroup the reduced RDDs and merge the reduced values + val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size + + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") + } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if 
(!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } + + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) + + Some(mergedValuesRDD) + } + + +} + + diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala new file mode 100644 index 0000000000..7d52e2eddf --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -0,0 +1,69 @@ +package spark.streaming + +import util.{ManualClock, RecurringTimer, Clock} +import spark.SparkEnv +import spark.Logging + +import scala.collection.mutable.HashMap + + +sealed trait SchedulerMessage +case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage + +class Scheduler(ssc: StreamingContext) +extends Logging { + + initLogging() + + val graph = ssc.graph + val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt + val jobManager = new JobManager(ssc, concurrentJobs) + val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") + val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] + val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_)) + + def start() { + // If context was started from checkpoint, then restart timer such that + // this timer's triggers occur at the same time as the original timer. + // Otherwise just start the timer from scratch, and initialize graph based + // on this first trigger time of the timer. 
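+    // A worked example of the restart arithmetic in RecurringTimer.restart (values
+    // chosen for illustration): with an original start at 1000 ms, a 2000 ms batch
+    // duration and a current clock of 7300 ms, gap = 6300 and
+    // (floor(6300 / 2000) + 1) * 2000 + 1000 = 9000 ms, i.e. the next tick that
+    // lines up with the original schedule.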
+ if (ssc.isCheckpointPresent) { + // If manual clock is being used for testing, then + // set manual clock to the last checkpointed time + if (clock.isInstanceOf[ManualClock]) { + val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds + clock.asInstanceOf[ManualClock].setTime(lastTime) + } + timer.restart(graph.zeroTime.milliseconds) + logInfo("Scheduler's timer restarted") + } else { + val firstTime = Time(timer.start()) + graph.start(firstTime - ssc.graph.batchDuration) + logInfo("Scheduler's timer started") + } + logInfo("Scheduler started") + } + + def stop() { + timer.stop() + graph.stop() + logInfo("Scheduler stopped") + } + + def generateRDDs(time: Time) { + SparkEnv.set(ssc.env) + logInfo("\n-----------------------------------------------------\n") + graph.generateRDDs(time).foreach(submitJob) + logInfo("Generated RDDs for time " + time) + graph.forgetOldRDDs(time) + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + ssc.doCheckpoint(time) + logInfo("Checkpointed at time " + time) + } + } + + def submitJob(job: Job) { + jobManager.runJob(job) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala new file mode 100644 index 0000000000..b566200273 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -0,0 +1,173 @@ +package spark.streaming + +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + +import java.io._ +import java.net.Socket +import java.util.concurrent.ArrayBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.Serializable + +class SocketInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) { + + def createReceiver(): NetworkReceiver[T] = { + new SocketReceiver(id, host, port, bytesToObjects, storageLevel) + } +} + + +class SocketReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkReceiver[T](streamId) { + + lazy protected val dataHandler = new DataHandler(this) + + protected def onStart() { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + dataHandler.start() + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } + + protected def onStop() { + dataHandler.stop() + } + + /** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. 
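+ *
+ * For illustration: with the blockInterval of 200 ms used below, objects received
+ * between t and t + 200 ms become one block named "input-<streamId>-<t>", are
+ * queued in blocksForPushing, and are stored by the block pushing thread via
+ * pushBlock.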
+ */ + class DataHandler(receiver: NetworkReceiver[T]) extends Serializable { + case class Block(id: String, iterator: Iterator[T]) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + streamId + "- " + (time - blockInterval) + val newBlock = new Block(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + pushBlock(block.id, block.iterator, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } +} + + +object SocketReceiver { + + /** + * This methods translates the data from an inputstream (say, from a socket) + * to '\n' delimited strings and returns an iterator to access the strings. 
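+ *
+ * A small usage sketch (input contents assumed):
+ * {{{
+ *   val in = new java.io.ByteArrayInputStream("hello\nworld\n".getBytes("UTF-8"))
+ *   SocketReceiver.bytesToLines(in).foreach(println)  // prints "hello", then "world"
+ * }}}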
+ */ + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + if (nextValue == null) { + finished = true + } + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + getNext() + if (finished) { + dataInputStream.close() + } + } + } + !finished + } + + override def next(): String = { + if (finished) { + throw new NoSuchElementException("End of stream") + } + if (!gotNext) { + getNext() + } + gotNext = false + nextValue + } + } + iterator + } +} diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala new file mode 100644 index 0000000000..086752ac55 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -0,0 +1,130 @@ +package spark.streaming + +import spark.RDD +import spark.rdd.BlockRDD +import spark.Partitioner +import spark.rdd.MapPartitionsRDD +import spark.SparkContext._ +import spark.storage.StorageLevel + + +class StateRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: Iterator[T] => Iterator[U], + rememberPartitioner: Boolean + ) extends MapPartitionsRDD[U, T](prev, f) { + override val partitioner = if (rememberPartitioner) prev.partitioner else None +} + +class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], + partitioner: Partitioner, + rememberPartitioner: Boolean + ) extends DStream[(K, S)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime = parent.slideTime + + override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { + generatedRDDs.get(time) match { + case Some(oldRDD) => { + if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { + val r = oldRDD + val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) + val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { + override val partitioner = oldRDD.partitioner + } + generatedRDDs.update(time, checkpointedRDD) + logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id) + Some(checkpointedRDD) + } else { + Some(oldRDD) + } + } + case None => { + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + generatedRDDs.put(time, newRDD) + Some(newRDD) + } + case None => { + None + } + } + } else { + None + } + } + } + } + + override def compute(validTime: Time): Option[RDD[(K, S)]] = { + + // Try to get the previous state RDD + getOrCompute(validTime - slideTime) match { + + case Some(prevStateRDD) => { // If previous state RDD exists + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as 
usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption) + }) + updateFuncLocal(i) + } + val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) + val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner) + //logDebug("Generating state RDD for time " + validTime) + return Some(stateRDD) + } + case None => { // If parent RDD does not exist, then return old state RDD + return Some(prevStateRDD) + } + } + } + + case None => { // If previous session RDD does not exist (first input data) + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) + } + + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner) + //logDebug("Generating state RDD for time " + validTime + " (first)") + return Some(sessionRDD) + } + case None => { // If parent RDD does not exist, then nothing to do! + //logDebug("Not generating state RDD (no previous state, no parent)") + return None + } + } + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala new file mode 100644 index 0000000000..7c7b3afe47 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -0,0 +1,229 @@ +package spark.streaming + +import spark.RDD +import spark.Logging +import spark.SparkEnv +import spark.SparkContext +import spark.storage.StorageLevel + +import scala.collection.mutable.Queue + +import java.io.InputStream +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat + +class StreamingContext ( + sc_ : SparkContext, + cp_ : Checkpoint + ) extends Logging { + + def this(sparkContext: SparkContext) = this(sparkContext, null) + + def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = + this(new SparkContext(master, frameworkName, sparkHome, jars), null) + + def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + + def this(cp_ : Checkpoint) = this(null, cp_) + + initLogging() + + if (sc_ == null && cp_ == null) { + throw new Exception("Streaming Context cannot be initilalized with " + + "both SparkContext and checkpoint as null") + } + + val isCheckpointPresent = (cp_ != null) + + val sc: SparkContext = { + if (isCheckpointPresent) { + new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars) + } else { + sc_ + } + } + + val env = SparkEnv.get + + val graph: DStreamGraph = { + if (isCheckpointPresent) { + + cp_.graph.setContext(this) + cp_.graph + } else { + new DStreamGraph() + } + } + + val nextNetworkInputStreamId = new AtomicInteger(0) + var 
networkInputTracker: NetworkInputTracker = null + + private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null + private[streaming] var receiverJobThread: Thread = null + private[streaming] var scheduler: Scheduler = null + + def setBatchDuration(duration: Time) { + graph.setBatchDuration(duration) + } + + def setRememberDuration(duration: Time) { + graph.setRememberDuration(duration) + } + + def setCheckpointDetails(file: String, interval: Time) { + checkpointFile = file + checkpointInterval = interval + } + + private[streaming] def getInitialCheckpoint(): Checkpoint = { + if (isCheckpointPresent) cp_ else null + } + + private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + + def networkTextStream( + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): DStream[String] = { + networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) + } + + def networkStream[T: ClassManifest]( + hostname: String, + port: Int, + converter: (InputStream) => Iterator[T], + storageLevel: StorageLevel + ): DStream[T] = { + val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel) + graph.addInputStream(inputStream) + inputStream + } + + def rawNetworkStream[T: ClassManifest]( + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 + ): DStream[T] = { + val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) + graph.addInputStream(inputStream) + inputStream + } + + /** + * This function creates a input stream that monitors a Hadoop-compatible + * for new files and executes the necessary processing on them. + */ + def fileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ](directory: String): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, directory) + graph.addInputStream(inputStream) + inputStream + } + + def textFileStream(directory: String): DStream[String] = { + fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) + } + + /** + * This function create a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue + */ + def queueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean = true, + defaultRDD: RDD[T] = null + ): DStream[T] = { + val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) + graph.addInputStream(inputStream) + inputStream + } + + def queueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = { + val queue = new Queue[RDD[T]] + val inputStream = queueStream(queue, true, null) + queue ++= array + inputStream + } + + /** + * This function registers a InputDStream as an input stream that will be + * started (InputDStream.start() called) to get the input data streams. + */ + def registerInputStream(inputStream: InputDStream[_]) { + graph.addInputStream(inputStream) + } + + /** + * This function registers a DStream as an output stream that will be + * computed every interval. + */ + def registerOutputStream(outputStream: DStream[_]) { + graph.addOutputStream(outputStream) + } + + def validate() { + assert(graph != null, "Graph is null") + graph.validate() + } + + /** + * This function starts the execution of the streams. 
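+ *
+ * A hypothetical end-to-end setup (master URL and port are assumptions); note that
+ * reduceByKey on a DStream needs the implicits from StreamingContext._ :
+ * {{{
+ *   import spark.streaming.StreamingContext._
+ *   val ssc = new StreamingContext("local[2]", "Example")
+ *   ssc.setBatchDuration(Seconds(1))
+ *   val lines = ssc.networkTextStream("localhost", 9999)
+ *   lines.flatMap(_.split(" ")).map(w => (w, 1)).reduceByKey(_ + _).print()
+ *   ssc.start()
+ * }}}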
+ */ + def start() { + validate() + + val networkInputStreams = graph.getInputStreams().filter(s => s match { + case n: NetworkInputDStream[_] => true + case _ => false + }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray + + if (networkInputStreams.length > 0) { + // Start the network input tracker (must start before receivers) + networkInputTracker = new NetworkInputTracker(this, networkInputStreams) + networkInputTracker.start() + } + + Thread.sleep(1000) + + // Start the scheduler + scheduler = new Scheduler(this) + scheduler.start() + } + + /** + * This function stops the execution of the streams. + */ + def stop() { + try { + if (scheduler != null) scheduler.stop() + if (networkInputTracker != null) networkInputTracker.stop() + if (receiverJobThread != null) receiverJobThread.interrupt() + sc.stop() + } catch { + case e: Exception => logWarning("Error while stopping", e) + } + + logInfo("StreamingContext stopped") + } + + def doCheckpoint(currentTime: Time) { + new Checkpoint(this, currentTime).saveToFile(checkpointFile) + } +} + + +object StreamingContext { + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { + new PairDStreamFunctions[K, V](stream) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala new file mode 100644 index 0000000000..9ddb65249a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -0,0 +1,56 @@ +package spark.streaming + +case class Time(millis: Long) { + + def < (that: Time): Boolean = (this.millis < that.millis) + + def <= (that: Time): Boolean = (this.millis <= that.millis) + + def > (that: Time): Boolean = (this.millis > that.millis) + + def >= (that: Time): Boolean = (this.millis >= that.millis) + + def + (that: Time): Time = Time(millis + that.millis) + + def - (that: Time): Time = Time(millis - that.millis) + + def * (times: Int): Time = Time(millis * times) + + def floor(that: Time): Time = { + val t = that.millis + val m = math.floor(this.millis / t).toLong + Time(m * t) + } + + def isMultipleOf(that: Time): Boolean = + (this.millis % that.millis == 0) + + def isZero: Boolean = (this.millis == 0) + + override def toString: String = (millis.toString + " ms") + + def toFormattedString: String = millis.toString + + def milliseconds: Long = millis +} + +object Time { + val zero = Time(0) + + implicit def toTime(long: Long) = Time(long) + + implicit def toLong(time: Time) = time.milliseconds +} + +object Milliseconds { + def apply(milliseconds: Long) = Time(milliseconds) +} + +object Seconds { + def apply(seconds: Long) = Time(seconds * 1000) +} + +object Minutes { + def apply(minutes: Long) = Time(minutes * 60000) +} + diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala new file mode 100644 index 0000000000..ce89a3f99b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -0,0 +1,36 @@ +package spark.streaming + +import spark.RDD +import spark.rdd.UnionRDD + + +class WindowedDStream[T: ClassManifest]( + parent: DStream[T], + _windowTime: Time, + _slideTime: Time) + extends DStream[T](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new 
Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + def windowTime: Time = _windowTime + + override def dependencies = List(parent) + + override def slideTime: Time = _slideTime + + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) + Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala new file mode 100644 index 0000000000..d2fdabd659 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala @@ -0,0 +1,32 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object CountRaw { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: CountRaw <master> <numStreams> <host> <port> <batchMillis>") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "CountRaw") + ssc.setBatchDuration(Milliseconds(batchMillis)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnionDStream(rawStreams) + union.map(_.length + 2).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala new file mode 100644 index 0000000000..d68611abd6 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala @@ -0,0 +1,47 @@ +package spark.streaming.examples + +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext._ +import spark.streaming.Seconds +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + + +object FileStream { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: FileStream <master> <new HDFS compatible directory>") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "FileStream") + ssc.setBatchDuration(Seconds(2)) + + // Create the new directory + val directory = new Path(args(1)) + val fs = directory.getFileSystem(new Configuration()) + if (fs.exists(directory)) throw new Exception("This directory already exists") + fs.mkdirs(directory) + fs.deleteOnExit(directory) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val inputStream = ssc.textFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + + // Creating new files in the directory + val text = "This is a text file" + for (i <- 1 to 30) { + ssc.sc.parallelize((1 to (i * 
10)).map(_ => text), 10) + .saveAsTextFile(new Path(directory, i.toString).toString) + Thread.sleep(1000) + } + Thread.sleep(5000) // Wait for the last files to be processed + ssc.stop() + System.exit(0) + } +}
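For reference, the textFileStream used above is shorthand for the generic fileStream defined in StreamingContext; a sketch of the equivalent explicit call, using the key/value and input-format types that textFileStream fixes (the directory value is reused from the example):

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapreduce.lib.input.TextInputFormat

    // Equivalent to ssc.textFileStream(...): keys are byte offsets, values are lines
    val kv = ssc.fileStream[LongWritable, Text, TextInputFormat](directory.toString)
    val lines = kv.map(_._2.toString)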
\ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala new file mode 100644 index 0000000000..df96a811da --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -0,0 +1,76 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +object FileStreamWithCheckpoint { + + def main(args: Array[String]) { + + if (args.size != 3) { + println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>") + println("FileStreamWithCheckpoint restart <directory> <checkpoint file>") + System.exit(-1) + } + + val directory = new Path(args(1)) + val checkpointFile = args(2) + + val ssc: StreamingContext = { + + if (args(0) == "restart") { + + // Recreated streaming context from specified checkpoint file + new StreamingContext(checkpointFile) + + } else { + + // Create directory if it does not exist + val fs = directory.getFileSystem(new Configuration()) + if (!fs.exists(directory)) fs.mkdirs(directory) + + // Create new streaming context + val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") + ssc_.setBatchDuration(Seconds(1)) + ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + + // Setup the streaming computation + val inputStream = ssc_.textFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + + ssc_ + } + } + + // Start the stream computation + startFileWritingThread(directory.toString) + ssc.start() + } + + def startFileWritingThread(directory: String) { + + val fs = new Path(directory).getFileSystem(new Configuration()) + + val fileWritingThread = new Thread() { + override def run() { + val r = new scala.util.Random() + val text = "This is a sample text file with a random number " + while(true) { + val number = r.nextInt() + val file = new Path(directory, number.toString) + val fos = fs.create(file) + fos.writeChars(text + number) + fos.close() + println("Created text file " + file) + Thread.sleep(1000) + } + } + } + fileWritingThread.start() + } + +} diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala new file mode 100644 index 0000000000..b1faa65c17 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala @@ -0,0 +1,64 @@ +package spark.streaming.examples + +import spark.SparkContext +import SparkContext._ +import spark.streaming._ +import StreamingContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object Grep2 { + + def warmup(sc: SparkContext) { + (0 until 10).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(x => (x % 337, x % 1331)) + .reduceByKey(_ + _) + .count() + } + } + + def main (args: Array[String]) { + + if (args.length != 6) { + println ("Usage: Grep2 <host> <file> <mapTasks> <reduceTasks> <batchMillis> <chkptMillis>") + System.exit(1) + } + + val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args + + val 
batchDuration = Milliseconds(batchMillis.toLong) + + val ssc = new StreamingContext(master, "Grep2") + ssc.setBatchDuration(batchDuration) + + //warmup(ssc.sc) + + val data = ssc.sc.textFile(file, mapTasks.toInt).persist( + new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas + println("Data count: " + data.count()) + println("Data count: " + data.count()) + println("Data count: " + data.count()) + + val sentences = new ConstantInputDStream(ssc, data) + ssc.registerInputStream(sentences) + + sentences.filter(_.contains("Culpepper")).count().foreachRDD(r => + println("Grep count: " + r.collect().mkString)) + + ssc.start() + + while(true) { Thread.sleep(1000) } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala new file mode 100644 index 0000000000..b1e1a613fe --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -0,0 +1,33 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object GrepRaw { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: GrepRaw <master> <numStreams> <host> <port> <batchMillis>") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "GrepRaw") + ssc.setBatchDuration(Milliseconds(batchMillis)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnionDStream(rawStreams) + union.filter(_.contains("Culpepper")).count().foreachRDD(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala new file mode 100644 index 0000000000..2af51bad28 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala @@ -0,0 +1,41 @@ +package spark.streaming.examples + +import spark.RDD +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext._ +import spark.streaming.Seconds + +import scala.collection.mutable.SynchronizedQueue + +object QueueStream { + + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: QueueStream <master>") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "QueueStream") + ssc.setBatchDuration(Seconds(1)) + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputDStream and use it do some processing + val inputStream = ssc.queueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sc.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +}
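A note on the queue semantics in the example above: QueueInputDStream consumes the queue differently depending on oneAtATime, as this sketch shows (names reused from the example):

    // oneAtATime = true (the default): each batch dequeues exactly one RDD, so the
    // 30 RDDs pushed above feed 30 consecutive batches.
    val onePerBatch = ssc.queueStream(rddQueue)
    // oneAtATime = false: each batch unions everything currently in the queue
    // (falling back to defaultRDD, here null, when the queue is empty).
    val allAtOnce = ssc.queueStream(rddQueue, false, null)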
\ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala new file mode 100644 index 0000000000..57fd10f0a5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -0,0 +1,95 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.SparkContext +import spark.SparkContext._ +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +import WordCount2_ExtraFunctions._ + +object TopKWordCountRaw { + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + + def main(args: Array[String]) { + if (args.length != 7) { + System.err.println("Usage: TopKWordCountRaw <master> <streams> <host> <port> <batchMs> <chkptMs> <reduces>") + System.exit(1) + } + + val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), + IntParam(chkptMs), IntParam(reduces)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "TopKWordCountRaw") + ssc.setBatchDuration(Milliseconds(batchMs)) + + // Make sure some tasks have started on each node + moreWarmup(ssc.sc) + + val rawStreams = (1 to streams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnionDStream(rawStreams) + + val windowedCounts = union.mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, + Milliseconds(chkptMs)) + //windowedCounts.print() // TODO: something else? 
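+    // How the window reduce above stays incremental (numbers for illustration):
+    // with a 30 s window sliding once per batch, the new count for a key is the
+    // previous window's count + reduce(newest batch) - reduce(oldest batch that
+    // slid out), so each slide reduces two batches instead of the whole window.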
+ + def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { + val taken = new Array[(String, Long)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, Long) = null + var swap: (String, Long) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 50 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + +// windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala new file mode 100644 index 0000000000..0d2e62b955 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -0,0 +1,115 @@ +package spark.streaming.examples + +import spark.SparkContext +import SparkContext._ +import spark.streaming._ +import StreamingContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object WordCount2_ExtraFunctions { + + def add(v1: Long, v2: Long) = (v1 + v2) + + def subtract(v1: Long, v2: Long) = (v1 - v2) + + def max(v1: Long, v2: Long) = math.max(v1, v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { + //val map = new java.util.HashMap[String, Long] + val map = new OLMap[String] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.getLong(w) + map.put(w, c + 1) +/* + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } +*/ + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator.map{case (k, v) => (k, v)} + } +} + +object WordCount2 { + + def warmup(sc: SparkContext) { + (0 until 3).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(x => (x % 337, x % 1331)) + .reduceByKey(_ + _, 100) + .count() + } + } + + def main (args: Array[String]) { + + if (args.length != 6) { + println ("Usage: WordCount2 <host> <file> <mapTasks> <reduceTasks> <batchMillis> <chkptMillis>") + System.exit(1) + } + + val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args + + val batchDuration = Milliseconds(batchMillis.toLong) + + val ssc = new StreamingContext(master, "WordCount2") + ssc.setBatchDuration(batchDuration) + + //warmup(ssc.sc) + + val data = ssc.sc.textFile(file, mapTasks.toInt).persist( + new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas + println("Data count: " + data.map(x => if (x == 
"") 1 else x.split(" ").size / x.split(" ").size).count()) + println("Data count: " + data.count()) + println("Data count: " + data.count()) + + val sentences = new ConstantInputDStream(ssc, data) + ssc.registerInputStream(sentences) + + import WordCount2_ExtraFunctions._ + + val windowedCounts = sentences + .mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) + windowedCounts.persist(StorageLevel.MEMORY_ONLY, + StorageLevel.MEMORY_ONLY_2, + //new StorageLevel(false, true, true, 3), + Milliseconds(chkptMillis.toLong)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + + while(true) { Thread.sleep(1000) } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala new file mode 100644 index 0000000000..591cb141c3 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala @@ -0,0 +1,26 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +object WordCountHdfs { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCountHdfs <master> <directory>") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "WordCountHdfs") + ssc.setBatchDuration(Seconds(2)) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.textFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala new file mode 100644 index 0000000000..ba1bd1de7c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala @@ -0,0 +1,25 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +object WordCountNetwork { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCountNetwork <master> <hostname> <port>") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "WordCountNetwork") + ssc.setBatchDuration(Seconds(2)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. 
generated by 'nc') + val lines = ssc.networkTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala new file mode 100644 index 0000000000..abfd12890f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -0,0 +1,51 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.SparkContext +import spark.SparkContext._ +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +import WordCount2_ExtraFunctions._ + +object WordCountRaw { + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + + def main(args: Array[String]) { + if (args.length != 7) { + System.err.println("Usage: WordCountRaw <master> <streams> <host> <port> <batchMs> <chkptMs> <reduces>") + System.exit(1) + } + + val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), + IntParam(chkptMs), IntParam(reduces)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "WordCountRaw") + ssc.setBatchDuration(Milliseconds(batchMs)) + + // Make sure some tasks have started on each node + moreWarmup(ssc.sc) + + val rawStreams = (1 to streams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnionDStream(rawStreams) + + val windowedCounts = union.mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, + Milliseconds(chkptMs)) + //windowedCounts.print() // TODO: something else? 
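+    // On the persist call above: its three arguments are the storage level for
+    // ordinary window RDDs (MEMORY_ONLY), the level used for RDDs written at a
+    // checkpoint (MEMORY_ONLY_2, two in-memory replicas), and the checkpointing
+    // period (every chkptMs milliseconds).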
+ windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala new file mode 100644 index 0000000000..9d44da2b11 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -0,0 +1,73 @@ +package spark.streaming.examples + +import spark.SparkContext +import SparkContext._ +import spark.streaming._ +import StreamingContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object WordMax2 { + + def warmup(sc: SparkContext) { + (0 until 10).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(x => (x % 337, x % 1331)) + .reduceByKey(_ + _) + .count() + } + } + + def main (args: Array[String]) { + + if (args.length != 6) { + println ("Usage: WordMax2 <host> <file> <mapTasks> <reduceTasks> <batchMillis> <chkptMillis>") + System.exit(1) + } + + val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args + + val batchDuration = Milliseconds(batchMillis.toLong) + + val ssc = new StreamingContext(master, "WordMax2") + ssc.setBatchDuration(batchDuration) + + //warmup(ssc.sc) + + val data = ssc.sc.textFile(file, mapTasks.toInt).persist( + new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas + println("Data count: " + data.count()) + println("Data count: " + data.count()) + println("Data count: " + data.count()) + + val sentences = new ConstantInputDStream(ssc, data) + ssc.registerInputStream(sentences) + + import WordCount2_ExtraFunctions._ + + val windowedCounts = sentences + .mapPartitions(splitAndCountPartitions) + .reduceByKey(add _, reduceTasks.toInt) + .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, + Milliseconds(chkptMillis.toLong)) + .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, + // Milliseconds(chkptMillis.toLong)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + + while(true) { Thread.sleep(1000) } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala new file mode 100644 index 0000000000..ed087e4ea8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala @@ -0,0 +1,84 @@ +package spark.streaming.util + +import spark.streaming._ + +trait Clock { + def currentTime(): Long + def waitTillTime(targetTime: Long): Long +} + + +class SystemClock() extends Clock { + + val minPollTime = 25L + + def currentTime(): Long = { + System.currentTimeMillis() + } + + def waitTillTime(targetTime: Long): Long = { + var currentTime = 0L + currentTime = System.currentTimeMillis() + + var waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + + val pollTime = { + if (waitTime / 10.0 > minPollTime) { + (waitTime / 10.0).toLong + } else { + minPollTime + } + } + + + while (true) { + currentTime = System.currentTimeMillis() + waitTime = targetTime - currentTime + + if (waitTime <= 0) { + + return currentTime + } + val sleepTime = + if (waitTime < pollTime) { + waitTime + } else { + pollTime + } + 
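+      // Sleep in slices of at most pollTime (a tenth of the original wait, floored
+      // at 25 ms) and re-check the clock after each slice, so an overshooting
+      // Thread.sleep cannot delay the return far past targetTime; e.g. a 3000 ms
+      // wait polls roughly every 300 ms.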
Thread.sleep(sleepTime) + } + return -1 + } +} + +class ManualClock() extends Clock { + + var time = 0L + + def currentTime() = time + + def setTime(timeToSet: Long) = { + this.synchronized { + time = timeToSet + this.notifyAll() + } + } + + def addToTime(timeToAdd: Long) = { + this.synchronized { + time += timeToAdd + this.notifyAll() + } + } + def waitTillTime(targetTime: Long): Long = { + this.synchronized { + while (time < targetTime) { + this.wait(100) + } + } + return currentTime() + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala b/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala new file mode 100644 index 0000000000..cde868a0c9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala @@ -0,0 +1,157 @@ +package spark.streaming.util + +import spark.Logging + +import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.nio.channels.spi._ + +abstract class ConnectionHandler(host: String, port: Int, connect: Boolean) +extends Thread with Logging { + + val selector = SelectorProvider.provider.openSelector() + val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + + initLogging() + + override def run() { + try { + if (connect) { + connect() + } else { + listen() + } + + var interrupted = false + while(!interrupted) { + + preSelect() + + while(!interestChangeRequests.isEmpty) { + val (key, ops) = interestChangeRequests.dequeue + val lastOps = key.interestOps() + key.interestOps(ops) + + def intToOpStr(op: Int): String = { + val opStrs = new ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } + + selector.select() + interrupted = Thread.currentThread.isInterrupted + + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext) { + val key = selectedKeys.next.asInstanceOf[SelectionKey] + selectedKeys.remove() + if (key.isValid) { + if (key.isAcceptable) { + accept(key) + } else if (key.isConnectable) { + finishConnect(key) + } else if (key.isReadable) { + read(key) + } else if (key.isWritable) { + write(key) + } + } + } + } + } catch { + case e: Exception => { + logError("Error in select loop", e) + } + } + } + + def connect() { + val socketAddress = new InetSocketAddress(host, port) + val channel = SocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setTcpNoDelay(true) + channel.connect(socketAddress) + channel.register(selector, SelectionKey.OP_CONNECT) + logInfo("Initiating connection to [" + socketAddress + "]") + } + + def listen() { + val channel = ServerSocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setReceiveBufferSize(256 * 1024) + channel.socket.bind(new InetSocketAddress(port)) + channel.register(selector, SelectionKey.OP_ACCEPT) + logInfo("Listening on port " + port) + } + + def finishConnect(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val address = 
channel.socket.getRemoteSocketAddress + channel.finishConnect() + logInfo("Connected to [" + host + ":" + port + "]") + ready(key) + } catch { + case e: IOException => { + logError("Error finishing connect to " + host + ":" + port) + close(key) + } + } + } + + def accept(key: SelectionKey) { + try { + val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] + val channel = serverChannel.accept() + val address = channel.socket.getRemoteSocketAddress + channel.configureBlocking(false) + logInfo("Accepted connection from [" + address + "]") + ready(channel.register(selector, 0)) + } catch { + case e: IOException => { + logError("Error accepting connection", e) + } + } + } + + def changeInterest(key: SelectionKey, ops: Int) { + logTrace("Added request to change ops to " + ops) + interestChangeRequests += ((key, ops)) + } + + def ready(key: SelectionKey) + + def preSelect() { + } + + def read(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + } + + def write(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + } + + def close(key: SelectionKey) { + try { + key.channel.close() + key.cancel() + Thread.currentThread.interrupt + } catch { + case e: Exception => logError("Error closing connection", e) + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala new file mode 100644 index 0000000000..d8b987ec86 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -0,0 +1,60 @@ +package spark.streaming.util + +import java.nio.ByteBuffer +import spark.util.{RateLimitedOutputStream, IntParam} +import java.net.ServerSocket +import spark.{Logging, KryoSerializer} +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import io.Source +import java.io.IOException + +/** + * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a + * specified rate. Used to feed data into RawInputDStream. 
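+ *
+ * A hypothetical invocation (port, file and sizes are assumptions): serve blocks of
+ * roughly 256 KB from words.txt at 10 MB/s on port 9999:
+ * {{{
+ *   ./run spark.streaming.util.RawTextSender 9999 words.txt 262144 10485760
+ * }}}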
+ */ +object RawTextSender extends Logging { + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println("Usage: RawTextSender <port> <file> <blockSize> <bytesPerSec>") + System.exit(1) + } + // Parse the arguments using a pattern match + val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args + + // Repeat the input data multiple times to fill in a buffer + val lines = Source.fromFile(file).getLines().toArray + val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) + val ser = new KryoSerializer().newInstance() + val serStream = ser.serializeStream(bufferStream) + var i = 0 + while (bufferStream.position < blockSize) { + serStream.writeObject(lines(i)) + i = (i + 1) % lines.length + } + bufferStream.trim() + val array = bufferStream.array + + val countBuf = ByteBuffer.wrap(new Array[Byte](4)) + countBuf.putInt(array.length) + countBuf.flip() + + val serverSocket = new ServerSocket(port) + logInfo("Listening on port " + port) + + while (true) { + val socket = serverSocket.accept() + logInfo("Got a new connection") + val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) + try { + while (true) { + out.write(countBuf.array) + out.write(array) + } + } catch { + case e: IOException => + logError("Client disconnected") + socket.close() + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala new file mode 100644 index 0000000000..dc55fd902b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -0,0 +1,73 @@ +package spark.streaming.util + +class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { + + val minPollTime = 25L + + val pollTime = { + if (period / 10.0 > minPollTime) { + (period / 10.0).toLong + } else { + minPollTime + } + } + + val thread = new Thread() { + override def run() { loop } + } + + var nextTime = 0L + + def start(startTime: Long): Long = { + nextTime = startTime + thread.start() + nextTime + } + + def start(): Long = { + val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period + start(startTime) + } + + def restart(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime + start(newStartTime) + } + + def stop() { + thread.interrupt() + } + + def loop() { + try { + while (true) { + clock.waitTillTime(nextTime) + callback(nextTime) + nextTime += period + } + + } catch { + case e: InterruptedException => + } + } +} + +object RecurringTimer { + + def main(args: Array[String]) { + var lastRecurTime = 0L + val period = 1000 + + def onRecur(time: Long) { + val currentTime = System.currentTimeMillis() + println("" + currentTime + ": " + (currentTime - lastRecurTime)) + lastRecurTime = currentTime + } + val timer = new RecurringTimer(new SystemClock(), period, onRecur) + timer.start() + Thread.sleep(30 * 1000) + timer.stop() + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala new file mode 100644 index 0000000000..3922dfbad6 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala @@ -0,0 +1,67 @@ +package spark.streaming.util + +import java.net.{Socket, ServerSocket} +import java.io.{ByteArrayOutputStream, DataOutputStream, 
DataInputStream, BufferedInputStream}
+
+object Receiver {
+  def main(args: Array[String]) {
+    val port = args(0).toInt
+    val lsocket = new ServerSocket(port)
+    println("Listening on port " + port)
+    while (true) {
+      val socket = lsocket.accept()
+      (new Thread() {
+        override def run() {
+          val buffer = new Array[Byte](100000) // unused
+          var count = 0
+          val time = System.currentTimeMillis
+          try {
+            val is = new DataInputStream(new BufferedInputStream(socket.getInputStream))
+            // readUTF() never returns null -- it signals end-of-stream by
+            // throwing EOFException -- so this loop exits via the catch below.
+            while (true) {
+              is.readUTF()
+              count += 28 // 2-byte UTF length prefix + 26 payload chars per record
+            }
+          } catch {
+            case e: Exception => e.printStackTrace()
+          }
+          val timeTaken = System.currentTimeMillis - time
+          val tput = (count / 1024.0) / (timeTaken / 1000.0)
+          println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s")
+        }
+      }).start()
+    }
+  }
+}
+
+object Sender {
+
+  def main(args: Array[String]) {
+    try {
+      val host = args(0)
+      val port = args(1).toInt
+      val size = args(2).toInt
+
+      val byteStream = new ByteArrayOutputStream()
+      val stringDataStream = new DataOutputStream(byteStream)
+      (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy"))
+      val bytes = byteStream.toByteArray()
+      println("Generated array of " + bytes.length + " bytes")
+
+      /*val bytes = new Array[Byte](size)*/
+      val socket = new Socket(host, port)
+      val os = socket.getOutputStream
+      os.write(bytes)
+      os.flush()
+      socket.close()
+    } catch {
+      case e: Exception => e.printStackTrace()
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala
new file mode 100644
index 0000000000..94e8f7a849
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala
@@ -0,0 +1,92 @@
+package spark.streaming.util
+
+import spark._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+import scala.io.Source
+
+import java.net.InetSocketAddress
+
+import org.apache.hadoop.fs._
+import org.apache.hadoop.conf._
+import org.apache.hadoop.io._
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.util._
+
+object SentenceFileGenerator {
+
+  def printUsage() {
+    println("Usage: SentenceFileGenerator <master> <target directory> <# partitions> <sentence file> [<sentences per second>]")
+    System.exit(0)
+  }
+
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      printUsage
+    }
+
+    val master = args(0)
+    val fs = new Path(args(1)).getFileSystem(new Configuration())
+    val targetDirectory = new Path(args(1)).makeQualified(fs)
+    val numPartitions = args(2).toInt
+    val sentenceFile = args(3)
+    val sentencesPerSecond = if (args.length > 4) args(4).toInt else 10
+
+    val source = Source.fromFile(sentenceFile)
+    val lines = source.mkString.split("\n").toArray
+    source.close()
+    println("Read " + lines.length + " lines from file " + sentenceFile)
+
+    // Sample sentencesPerSecond random lines up front; this fixed batch is
+    // re-written to a new directory once per second in the loop below.
+    val sentences = {
+      val buffer = ArrayBuffer[String]()
+      val random = new Random()
+      var i = 0
+      while (i < sentencesPerSecond) {
+        buffer += lines(random.nextInt(lines.length))
+        i += 1
+      }
+      buffer.toArray
+    }
+    println("Generated " + sentences.length + " sentences")
+
+    val sc = new SparkContext(master, "SentenceFileGenerator")
+    val sentencesRDD = sc.parallelize(sentences, numPartitions)
+
+    val tempDirectory = new Path(targetDirectory, "_tmp")
+
+    fs.mkdirs(targetDirectory)
+    fs.mkdirs(tempDirectory)
+
+    var saveTimeMillis =
System.currentTimeMillis + try { + while (true) { + val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) + val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) + println("Writing to file " + newDir) + sentencesRDD.saveAsTextFile(tmpNewDir.toString) + fs.rename(tmpNewDir, newDir) + saveTimeMillis += 1000 + val sleepTimeMillis = { + val currentTimeMillis = System.currentTimeMillis + if (saveTimeMillis < currentTimeMillis) { + 0 + } else { + saveTimeMillis - currentTimeMillis + } + } + println("Sleeping for " + sleepTimeMillis + " ms") + Thread.sleep(sleepTimeMillis) + } + } catch { + case e: Exception => + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala new file mode 100644 index 0000000000..60085f4f88 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala @@ -0,0 +1,23 @@ +package spark.streaming.util + +import spark.SparkContext +import SparkContext._ + +object ShuffleTest { + def main(args: Array[String]) { + + if (args.length < 1) { + println ("Usage: ShuffleTest <host>") + System.exit(1) + } + + val sc = new spark.SparkContext(args(0), "ShuffleTest") + val rdd = sc.parallelize(1 to 1000, 500).cache + + def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } + + time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } + System.exit(0) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala new file mode 100644 index 0000000000..23e9235c60 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala @@ -0,0 +1,107 @@ +package spark.streaming.util + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + + +object TestGenerator { + + def printUsage { + println ("Usage: SentenceGenerator <target IP> <target port> <sentence file> [<sentences per second>]") + System.exit(0) + } + /* + def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + + try { + var lastPrintTime = System.currentTimeMillis() + var count = 0 + while(true) { + streamReceiver ! lines(random.nextInt(lines.length)) + count += 1 + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + Thread.sleep(sleepBetweenSentences.toLong) + } + } catch { + case e: Exception => + } + }*/ + + def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + try { + val numSentences = if (sentencesPerSecond <= 0) { + lines.length + } else { + sentencesPerSecond + } + val sentences = lines.take(numSentences).toArray + + var nextSendingTime = System.currentTimeMillis() + val sendAsArray = true + while(true) { + if (sendAsArray) { + println("Sending as array") + streamReceiver !? sentences + } else { + println("Sending individually") + sentences.foreach(sentence => { + streamReceiver !? 
sentence + }) + } + println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms") + nextSendingTime += 1000 + val sleepTime = nextSendingTime - System.currentTimeMillis + if (sleepTime > 0) { + println ("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } + } + } catch { + case e: Exception => + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val generateRandomly = false + + val streamReceiverIP = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 + val sentenceInputName = if (args.length > 4) args(4) else "Sentences" + + println("Sending " + sentencesPerSecond + " sentences per second to " + + streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + val streamReceiver = select( + Node(streamReceiverIP, streamReceiverPort), + Symbol("NetworkStreamReceiver-" + sentenceInputName)) + if (generateRandomly) { + /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/ + } else { + generateSameSentences(lines, sentencesPerSecond, streamReceiver) + } + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala new file mode 100644 index 0000000000..ff840d084f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala @@ -0,0 +1,119 @@ +package spark.streaming.util + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream} +import java.net.Socket + +object TestGenerator2 { + + def printUsage { + println ("Usage: SentenceGenerator <target IP> <target port> <sentence file> [<sentences per second>]") + System.exit(0) + } + + def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){ + try { + println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort) + val socket = new Socket(streamReceiverHost, streamReceiverPort) + + println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime + Thread.sleep(targetTime - currentTime) + + while(true) { + val startTime = System.currentTimeMillis() + println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") + val socketOutputStream = socket.getOutputStream + val parts = 10 + (0 until parts).foreach(i => { + val partStartTime = System.currentTimeMillis + + val offset = (i * bytes.length / parts).toInt + val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length) + socketOutputStream.write(bytes, offset, len) + socketOutputStream.flush() + val partFinishTime = System.currentTimeMillis + println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms") + val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + }) + + socketOutputStream.flush() + 
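+        // Editorial note: the per-part sleep above is budgeted against a
+        // hard-coded second (1000 / parts), not against intervalTime. With
+        // parts = 10 each chunk gets ~100 ms minus its own write time, so the
+        // within-interval smoothing only lines up when intervalTime is about
+        // 1000 ms; the interval-level pacing further down (targetTime +=
+        // intervalTime) is what actually enforces the requested rate.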
/*val socketInputStream = new DataInputStream(socket.getInputStream)*/ + /*val reply = socketInputStream.readUTF()*/ + val finishTime = System.currentTimeMillis() + println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]") + /*println("Received = " + reply)*/ + targetTime = targetTime + intervalTime + val sleepTime = (targetTime - finishTime) + 10 + if (sleepTime > 0) { + println("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + println("############################") + println("###### Skipping sleep ######") + println("############################") + } + } + } catch { + case e: Exception => println(e) + } + println("Stopped sending") + } + + def main(args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalTime = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + println("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + println("Generating sentences") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + println("Converting to byte array") + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + /*stringDataStream.writeInt(sentences.size)*/ + sentences.foreach(stringDataStream.writeUTF) + val bytes = byteStream.toByteArray() + stringDataStream.close() + println("Generated array of " + bytes.length + " bytes") + + /*while(true) { */ + sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime) + /*println("Sleeping for 5 seconds")*/ + /*Thread.sleep(5000)*/ + /*System.gc()*/ + /*}*/ + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala new file mode 100644 index 0000000000..9c39ef3e12 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala @@ -0,0 +1,244 @@ +package spark.streaming.util + +import spark.Logging + +import scala.util.Random +import scala.io.Source +import scala.collection.mutable.{ArrayBuffer, Queue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ + +import it.unimi.dsi.fastutil.io._ + +class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int) +extends Logging { + + class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4) + extends ConnectionHandler(host, port, true) { + + val buffers = new ArrayBuffer[ByteBuffer] + val newBuffers = new Queue[ByteBuffer] + var activeKey: SelectionKey = null + + def send(buffer: ByteBuffer) { + logDebug("Sending: " + buffer) + newBuffers.synchronized { + newBuffers.enqueue(buffer) + } + selector.wakeup() + buffer.synchronized { + buffer.wait() + } + } + + override def ready(key: SelectionKey) { + logDebug("Ready") + activeKey = key + val channel = key.channel.asInstanceOf[SocketChannel] + channel.register(selector, SelectionKey.OP_WRITE) + 
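+      // Editorial note: ready() fires once the non-blocking connect completes.
+      // Registering for OP_WRITE makes the selector loop call write(key)
+      // whenever the socket can accept data, and the call below wakes the
+      // generator so it starts enqueueing ByteBuffers via send(), which
+      // re-arms OP_WRITE through changeInterest() in preSelect().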
generator.startSending() + } + + override def preSelect() { + newBuffers.synchronized { + while(!newBuffers.isEmpty) { + val buffer = newBuffers.dequeue + buffers += buffer + logDebug("Added: " + buffer) + changeInterest(activeKey, SelectionKey.OP_WRITE) + } + } + } + + override def write(key: SelectionKey) { + try { + /*while(true) {*/ + val channel = key.channel.asInstanceOf[SocketChannel] + if (buffers.size > 0) { + val buffer = buffers(0) + val newBuffer = buffer.slice() + newBuffer.limit(math.min(newBuffer.remaining, 32768)) + val bytesWritten = channel.write(newBuffer) + buffer.position(buffer.position + bytesWritten) + if (bytesWritten == 0) return + if (buffer.remaining == 0) { + buffers -= buffer + buffer.synchronized { + buffer.notify() + } + } + /*changeInterest(key, SelectionKey.OP_WRITE)*/ + } else { + changeInterest(key, 0) + } + /*}*/ + } catch { + case e: IOException => { + if (e.toString.contains("pipe") || e.toString.contains("reset")) { + logError("Connection broken") + } else { + logError("Connection error", e) + } + close(key) + } + } + } + + override def close(key: SelectionKey) { + buffers.clear() + super.close(key) + } + } + + initLogging() + + val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this) + var sendingThread: Thread = null + var sendCount = 0 + val sendBatches = 5 + + def run() { + logInfo("Connection handler started") + connectionHandler.start() + connectionHandler.join() + if (sendingThread != null && !sendingThread.isInterrupted) { + sendingThread.interrupt + } + logInfo("Connection handler stopped") + } + + def startSending() { + sendingThread = new Thread() { + override def run() { + logInfo("STARTING TO SEND") + sendSentences() + logInfo("SENDING STOPPED AFTER " + sendCount) + connectionHandler.interrupt() + } + } + sendingThread.start() + } + + def stopSending() { + sendingThread.interrupt() + } + + def sendSentences() { + logInfo("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + logInfo("Generating sentence buffer") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + /* + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take((numSentences / sendBatches).toInt).toArray + } else { + (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray + }*/ + + + val serializer = new spark.KryoSerializer().newInstance() + val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024) + serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close() + byteStream.trim() + val sentenceBuffer = ByteBuffer.wrap(byteStream.array) + + logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration + Thread.sleep(targetTime - currentTime) + + val totalBytes = sentenceBuffer.limit + + while(true) { + val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches + + val startTime = System.currentTimeMillis() + logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay 
of " + (startTime - targetTime) + " ms") + + (0 until batchesInCurrentInterval).foreach(i => { + try { + val position = (i * totalBytes / sendBatches).toInt + val limit = if (i == sendBatches - 1) { + totalBytes + } else { + ((i + 1) * totalBytes / sendBatches).toInt - 1 + } + + val partStartTime = System.currentTimeMillis + sentenceBuffer.limit(limit) + connectionHandler.send(sentenceBuffer) + val partFinishTime = System.currentTimeMillis + val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + + } catch { + case ie: InterruptedException => return + case e: Exception => e.printStackTrace() + } + }) + sentenceBuffer.rewind() + + val finishTime = System.currentTimeMillis() + /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/ + targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0) + + val sleepTime = (targetTime - finishTime) + 20 + if (sleepTime > 0) { + logInfo("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + logInfo("###### Skipping sleep ######") + } + if (Thread.currentThread.isInterrupted) { + return + } + sendCount += 1 + } + } +} + +object TestGenerator4 { + def printUsage { + println("Usage: TestGenerator4 <target IP> <target port> <sentence file> <interval duration> [<sentences per second>]") + System.exit(0) + } + + def main(args: Array[String]) { + println("GENERATOR STARTED") + if (args.length < 4) { + printUsage + } + + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalDuration = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + while(true) { + val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval) + generator.run() + Thread.sleep(2000) + } + println("GENERATOR STOPPED") + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala new file mode 100644 index 0000000000..f584f772bb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala @@ -0,0 +1,39 @@ +package spark.streaming.util + +import spark.streaming._ +import spark.Logging + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ + +sealed trait TestStreamCoordinatorMessage +case class GetStreamDetails extends TestStreamCoordinatorMessage +case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage +case class TestStarted extends TestStreamCoordinatorMessage + +class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging { + + var index = 0 + + initLogging() + + logInfo("Created") + + def receive = { + case TestStarted => { + sender ! "OK" + } + + case GetStreamDetails => { + val streamDetail = if (index >= streamDetails.length) null else streamDetails(index) + sender ! 
GotStreamDetails(streamDetail._1, streamDetail._2)
+      // FIXME: streamDetail is null once index has run past streamDetails,
+      // so the reply above would throw a NullPointerException for late requests.
+      index += 1
+      if (streamDetail != null) {
+        logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")")
+      }
+    }
+  }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala
new file mode 100644
index 0000000000..80ad924dd8
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala
@@ -0,0 +1,421 @@
+package spark.streaming.util
+
+import spark._
+import spark.storage._
+import spark.util.AkkaUtils
+import spark.streaming._
+
+import scala.math._
+import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap}
+
+import akka.actor._
+import akka.actor.Actor
+import akka.dispatch._
+import akka.pattern.ask
+import akka.util.duration._
+
+import java.io.DataInputStream
+import java.io.BufferedInputStream
+import java.net.Socket
+import java.net.ServerSocket
+import java.util.LinkedHashMap
+
+import org.apache.hadoop.fs._
+import org.apache.hadoop.conf._
+import org.apache.hadoop.io._
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.util._
+
+import spark.Utils
+
+
+class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager)
+extends Thread with Logging {
+
+  class DataHandler(
+      inputName: String,
+      longIntervalDuration: Time,
+      shortIntervalDuration: Time,
+      blockManager: BlockManager
+    ) extends Logging {
+
+    class Block(var id: String, var shortInterval: Interval) {
+      val data = ArrayBuffer[String]()
+      var pushed = false
+      def longInterval = getLongInterval(shortInterval)
+      def empty() = (data.size == 0)
+      def += (str: String) = (data += str)
+      override def toString() = "Block " + id
+    }
+
+    class Bucket(val longInterval: Interval) {
+      val blocks = new ArrayBuffer[Block]()
+      var filled = false
+      def += (block: Block) = blocks += block
+      def empty() = (blocks.size == 0)
+      def ready() = (filled && !blocks.exists(!
_.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val shortIntervalDurationMillis = shortIntervalDuration.toLong + val longIntervalDurationMillis = longIntervalDuration.toLong + + var currentBlock: Block = null + var currentBucket: Bucket = null + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } } + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + def start() { + blockUpdatingThread.start() + blockPushingThread.start() + } + + def += (data: String) = addData(data) + + def addData(data: String) { + if (currentBlock == null) { + updateCurrentBlock() + } + currentBlock.synchronized { + currentBlock += data + } + } + + def getShortInterval(time: Time): Interval = { + val intervalBegin = time.floor(shortIntervalDuration) + Interval(intervalBegin, intervalBegin + shortIntervalDuration) + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def updateCurrentBlock() { + /*logInfo("Updating current block")*/ + val currentTime = Time(System.currentTimeMillis) + val shortInterval = getShortInterval(currentTime) + val longInterval = getLongInterval(shortInterval) + + def createBlock(reuseCurrentBlock: Boolean = false) { + val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size + if (!reuseCurrentBlock) { + val newBlock = new Block(newBlockId, shortInterval) + /*logInfo("Created " + currentBlock)*/ + currentBlock = newBlock + } else { + currentBlock.shortInterval = shortInterval + currentBlock.id = newBlockId + } + } + + def createBucket() { + val newBucket = new Bucket(longInterval) + buckets += ((longInterval, newBucket)) + currentBucket = newBucket + /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/ + } + + if (currentBlock == null || currentBucket == null) { + createBucket() + currentBucket.synchronized { + createBlock() + } + return + } + + currentBlock.synchronized { + var reuseCurrentBlock = false + + if (shortInterval != currentBlock.shortInterval) { + if (!currentBlock.empty) { + blocksForPushing.synchronized { + blocksForPushing += currentBlock + blocksForPushing.notifyAll() + } + } + + currentBucket.synchronized { + if (currentBlock.empty) { + reuseCurrentBlock = true + } else { + currentBucket += currentBlock + } + + if (longInterval != currentBucket.longInterval) { + currentBucket.filled = true + if (currentBucket.ready) { + currentBucket.notifyAll() + } + createBucket() + } + } + + createBlock(reuseCurrentBlock) + } + } + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + logInfo("Pushing block") + val startTime = System.currentTimeMillis + + val bytes = blockManager.dataSerialize("rdd_", block.data.toIterator) // TODO: Will this be an RDD block? 
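+        // Editorial note: dataSerialize converts the buffered strings into the
+        // block manager's serialized form (the "rdd_" prefix and the TODO above
+        // suggest these bytes are meant to be read back later as RDD blocks).
+        // StorageLevel.MEMORY_AND_DISK_SER_2 below keeps the serialized bytes
+        // in memory, spills to disk when needed, and replicates them to a
+        // second node, so an interval's blocks survive the loss of one receiver.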
+ val finishTime = System.currentTimeMillis + logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s") + + blockManager.putBytes(block.id.toString, bytes, StorageLevel.MEMORY_AND_DISK_SER_2) + /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/ + val finishTime1 = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepUpdatingCurrentBlock() { + logInfo("Thread to update current block started") + while(true) { + updateCurrentBlock() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def keepPushingBlocks() { + var loop = true + logInfo("Thread to push blocks started") + while(loop) { + val block = blocksForPushing.synchronized { + if (blocksForPushing.size == 0) { + blocksForPushing.wait() + } + blocksForPushing.dequeue + } + pushBlock(block) + block.pushed = true + block.data.clear() + + val bucket = buckets(block.longInterval) + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + } + } + } + + + class ConnectionListener(port: Int, dataHandler: DataHandler) + extends Thread with Logging { + initLogging() + override def run { + try { + val listener = new ServerSocket(port) + logInfo("Listening on port " + port) + while (true) { + new ConnectionHandler(listener.accept(), dataHandler).start(); + } + listener.close() + } catch { + case e: Exception => logError("", e); + } + } + } + + class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging { + initLogging() + override def run { + logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort) + val bytes = new Array[Byte](100 * 1024 * 1024) + try { + + val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024)) + /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/ + var str: String = null + str = inputStream.readUTF + while(str != null) { + dataHandler += str + str = inputStream.readUTF() + } + + /* + var loop = true + while(loop) { + val numRead = inputStream.read(bytes) + if (numRead < 0) { + loop = false + } + inbox += ((LongTime(SystemTime.currentTimeMillis), "test")) + }*/ + + inputStream.close() + } catch { + case e => logError("Error receiving data", e) + } + socket.close() + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host") + val masterPort = System.getProperty("spark.master.port").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + 
masterPort) + + val timeout = 50 millis + + var started = false + while (!started) { + askActor[String](testStreamCoordinator, TestStarted) match { + case Some(str) => { + started = true + logInfo("TestStreamCoordinator started") + } + case None => { + logInfo("TestStreamCoordinator not started yet") + Thread.sleep(200) + } + } + } + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = Time(intervalDurationMillis) + + val dataHandler = new DataHandler( + inputName, + intervalDuration, + Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), + blockManager) + + val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connListener.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + while(true) { + waitFor(interval.endTime) + logInfo("Woken up at " + System.currentTimeMillis + " for " + interval) + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logInfo("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logInfo("Waiting for " + bucket) + bucket.wait() + logInfo("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logInfo("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logInfo("Found none for " + interval) + if (dataStarted) { + logInfo("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! 
InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime + val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 + logInfo("Pushing delay for " + time + " is " + delay + " s") + } catch { + case _ => logError("Exception notifying scheduler at interval " + interval) + } + } +} + +object TestStreamReceiver3 { + + val PORT = 9999 + val SHORT_INTERVAL_MILLIS = 100 + + def main(args: Array[String]) { + System.setProperty("spark.master.host", Utils.localHostName) + System.setProperty("spark.master.port", "7078") + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver3(actorSystem, null).start() + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala new file mode 100644 index 0000000000..31754870dd --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala @@ -0,0 +1,374 @@ +package spark.streaming.util + +import spark.streaming._ +import spark._ +import spark.storage._ +import spark.util.AkkaUtils + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} + +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.util.concurrent.Executors + +import akka.actor._ +import akka.actor.Actor +import akka.dispatch._ +import akka.pattern.ask +import akka.util.duration._ + +class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager) +extends Thread with Logging { + + class DataHandler( + inputName: String, + longIntervalDuration: Time, + shortIntervalDuration: Time, + blockManager: BlockManager + ) + extends Logging { + + class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) { + var pushed = false + def longInterval = getLongInterval(shortInterval) + override def toString() = "Block " + id + } + + class Bucket(val longInterval: Interval) { + val blocks = new ArrayBuffer[Block]() + var filled = false + def += (block: Block) = blocks += block + def empty() = (blocks.size == 0) + def ready() = (filled && !blocks.exists(! 
_.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val syncOnLastShortInterval = true + + val shortIntervalDurationMillis = shortIntervalDuration.milliseconds + val longIntervalDurationMillis = longIntervalDuration.milliseconds + + val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) + var currentShortInterval = Interval.currentInterval(shortIntervalDuration) + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } } + val blockPushingExecutor = Executors.newFixedThreadPool(5) + + + def start() { + buffer.clear() + if (buffer.remaining == 0) { + throw new Exception("Buffer initialization error") + } + bufferProcessingThread.start() + } + + def readDataToBuffer(func: ByteBuffer => Int): Int = { + buffer.synchronized { + if (buffer.remaining == 0) { + logInfo("Received first data for interval " + currentShortInterval) + } + func(buffer) + } + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def processBuffer() { + + def readInt(buffer: ByteBuffer): Int = { + var offset = 0 + var result = 0 + while (offset < 32) { + val b = buffer.get() + result |= ((b & 0x7F) << offset) + if ((b & 0x80) == 0) { + return result + } + offset += 7 + } + throw new Exception("Malformed zigzag-encoded integer") + } + + val currentLongInterval = getLongInterval(currentShortInterval) + val startTime = System.currentTimeMillis + val newBuffer: ByteBuffer = buffer.synchronized { + buffer.flip() + if (buffer.remaining == 0) { + buffer.clear() + null + } else { + logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms") + val startTime1 = System.currentTimeMillis + var loop = true + var count = 0 + while(loop) { + buffer.mark() + try { + val len = readInt(buffer) + buffer.position(buffer.position + len) + count += 1 + } catch { + case e: Exception => { + buffer.reset() + loop = false + } + } + } + val bytesToCopy = buffer.position + val newBuf = ByteBuffer.allocate(bytesToCopy) + buffer.position(0) + newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer]) + newBuf.flip() + buffer.position(bytesToCopy) + buffer.compact() + newBuf + } + } + + if (newBuffer != null) { + val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval)) + bucket.synchronized { + val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString + val newBlock = new Block(newBlockId, currentShortInterval, newBuffer) + if (syncOnLastShortInterval) { + bucket += newBlock + } + logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" ) + blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) + } + } + + val newShortInterval = Interval.currentInterval(shortIntervalDuration) + val newLongInterval = getLongInterval(newShortInterval) + + if (newLongInterval != currentLongInterval) { + buckets.get(currentLongInterval) match { + case Some(bucket) => { + bucket.synchronized { + 
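+            // Editorial note: once the long interval rolls over, the bucket is
+            // sealed (filled = true below) and any thread parked in
+            // bucket.wait() is woken; ready() only holds when the bucket is
+            // filled and every block in it has been pushed to the block manager.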
bucket.filled = true + if (bucket.ready) { + bucket.notifyAll() + } + } + } + case None => + } + buckets += ((newLongInterval, new Bucket(newLongInterval))) + } + + currentShortInterval = newShortInterval + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + val startTime = System.currentTimeMillis + logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms") + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ + blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + val finishTime = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime - startTime) + " ms") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepProcessingBuffers() { + logInfo("Thread to process buffers started") + while(true) { + processBuffer() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def pushAndNotifyBlock(block: Block) { + pushBlock(block) + block.pushed = true + val bucket = if (syncOnLastShortInterval) { + buckets(block.longInterval) + } else { + var longInterval = block.longInterval + while(!buckets.contains(longInterval)) { + logWarning("Skipping bucket of " + longInterval + " for " + block) + longInterval = longInterval.next + } + val chosenBucket = buckets(longInterval) + logDebug("Choosing bucket of " + longInterval + " for " + block) + chosenBucket += block + chosenBucket + } + + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + + } + } + + + class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler) + extends ConnectionHandler(host, port, false) { + + override def ready(key: SelectionKey) { + changeInterest(key, SelectionKey.OP_READ) + } + + override def read(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val bytesRead = dataHandler.readDataToBuffer(channel.read) + if (bytesRead < 0) { + close(key) + } + } catch { + case e: IOException => { + logError("Error reading", e) + close(key) + } + } + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host", "localhost") + val masterPort = System.getProperty("spark.master.port", "7078").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + masterPort) + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => 
details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = Milliseconds(intervalDurationMillis) + val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt) + + val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager) + val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler) + + val timeout = 100 millis + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connectionHandler.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + + while(true) { + waitFor(interval.endTime) + /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/ + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logDebug("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logDebug("Waiting for " + bucket) + bucket.wait() + logDebug("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logDebug("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logDebug("Found none for " + interval) + if (dataStarted) { + logDebug("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! 
InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime + val delay = (System.currentTimeMillis - time.milliseconds) + logInfo("Notification delay for " + time + " is " + delay + " ms") + } catch { + case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e) + } + } +} + + +object TestStreamReceiver4 { + def main(args: Array[String]) { + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver4(actorSystem, null).start() + } +} diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties new file mode 100644 index 0000000000..02fe16866e --- /dev/null +++ b/streaming/src/test/resources/log4j.properties @@ -0,0 +1,8 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala new file mode 100644 index 0000000000..d0aaac0f2e --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -0,0 +1,213 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ +import scala.runtime.RichInt +import util.ManualClock + +class BasicOperationsSuite extends TestSuiteBase { + + override def framework() = "BasicOperationsSuite" + + test("map") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.map(_.toString), + input.map(_.map(_.toString)) + ) + } + + test("flatmap") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), + input.map(_.flatMap(x => Array(x, x * 2))) + ) + } + + test("filter") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.filter(x => (x % 2 == 0)), + input.map(_.filter(x => (x % 2 == 0))) + ) + } + + test("glom") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq( + Seq( Seq(1, 2), Seq(3, 4) ), + Seq( Seq(5, 6), Seq(7, 8) ), + Seq( Seq(9, 10), Seq(11, 12) ) + ) + val operation = (r: DStream[Int]) => r.glom().map(_.toSeq) + testOperation(input, operation, output) + } + + test("mapPartitions") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23)) + val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _))) + testOperation(input, operation, output, true) + } + + test("groupByKey") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(), + Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ), + true + ) + } + + test("reduceByKey") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 
1)).reduceByKey(_ + _), + Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + true + ) + } + + test("reduce") { + testOperation( + Seq(1 to 4, 5 to 8, 9 to 12), + (s: DStream[Int]) => s.reduce(_ + _), + Seq(Seq(10), Seq(26), Seq(42)) + ) + } + + test("mapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10), + Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ), + true + ) + } + + test("flatMapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), + true + ) + } + + test("cogroup") { + val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) + val outputData = Seq( + Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ), + Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ), + Seq( ("", (Seq(1), Seq())) ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + + test("join") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) + val outputData = Seq( + Seq( ("a", (1, "x")), ("b", (1, "x")) ), + Seq( ("", (1, "x")) ), + Seq( ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).join(s2.map(x => (x,"x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + + test("updateStateByKey") { + val inputData = + Seq( + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val updateStateOperation = (s: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) + } + + testOperation(inputData, updateStateOperation, outputData, true) + } + + test("forgetting of RDDs - map and window operations") { + assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") + + val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq + val rememberDuration = Seconds(3) + + assert(input.size === 10, "Number of inputs have changed") + + def operation(s: DStream[Int]): DStream[(Int, Int)] = { + s.map(x => (x % 10, 1)) + .window(Seconds(2), Seconds(1)) + .window(Seconds(4), Seconds(2)) + } + + val ssc = setupStreams(input, operation _) + ssc.setRememberDuration(rememberDuration) + runStreams[(Int, Int)](ssc, input.size, input.size / 2) + + val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head + val windowedStream1 = windowedStream2.dependencies.head + val mappedStream = windowedStream1.dependencies.head + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + assert(clock.time === Seconds(10).milliseconds) + + // IDEALLY + // WindowedStream2 should remember till 
7 seconds: 10, 8,
+    // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5
+    // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3,
+
+    // IN THIS TEST
+    // WindowedStream2 should remember till 7 seconds: 10, 8,
+    // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4
+    // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2
+
+    // WindowedStream2
+    assert(windowedStream2.generatedRDDs.contains(Seconds(10)))
+    assert(windowedStream2.generatedRDDs.contains(Seconds(8)))
+    assert(!windowedStream2.generatedRDDs.contains(Seconds(6)))
+
+    // WindowedStream1
+    assert(windowedStream1.generatedRDDs.contains(Seconds(10)))
+    assert(windowedStream1.generatedRDDs.contains(Seconds(4)))
+    assert(!windowedStream1.generatedRDDs.contains(Seconds(3)))
+
+    // MappedStream
+    assert(mappedStream.generatedRDDs.contains(Seconds(10)))
+    assert(mappedStream.generatedRDDs.contains(Seconds(2)))
+    assert(!mappedStream.generatedRDDs.contains(Seconds(1)))
+  }
+}
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
new file mode 100644
index 0000000000..6dcedcf463
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -0,0 +1,53 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+import java.io.File
+
+class CheckpointSuite extends TestSuiteBase {
+
+  override def framework() = "CheckpointSuite"
+
+  override def checkpointFile() = "checkpoint"
+
+  def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V],
+      expectedOutput: Seq[Seq[V]],
+      useSet: Boolean = false
+    ) {
+
+    // Current code assumes that:
+    // number of inputs = number of outputs = number of batches to be run
+    val totalNumBatches = input.size
+    val initialNumBatches = input.size / 2
+    val nextNumBatches = totalNumBatches - initialNumBatches
+    val initialNumExpectedOutputs = initialNumBatches
+
+    // Do half the computation (half the number of batches), create checkpoint file and quit
+    val ssc = setupStreams[U, V](input, operation)
+    val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
+    verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet)
+    Thread.sleep(1000)
+
+    // Restart and complete the computation from checkpoint file
+    val sscNew = new StreamingContext(checkpointFile)
+    sscNew.setCheckpointDetails(null, null)
+    val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size)
+    verifyOutput[V](outputNew, expectedOutput, useSet)
+
+    new File(checkpointFile).delete()
+    new File(checkpointFile + ".bk").delete()
+    new File("." + checkpointFile + ".crc").delete()
+    new File("." + checkpointFile + ".bk.crc").delete()
+  }
+
+  test("simple per-batch operation") {
+    testCheckpointedOperation(
+      Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
+      (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+      Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+      true
+    )
+  }
+}
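+// Editorial note on the batch arithmetic above, using the six-batch input of
+// "simple per-batch operation": totalNumBatches = 6, initialNumBatches = 3 and
+// nextNumBatches = 3. The first context runs three batches, verifies the first
+// three outputs and writes the checkpoint; the rebuilt context is then expected
+// to end up with all six outputs. The cleanup deletes the checkpoint, its ".bk"
+// backup, and the "." + name + ".crc" shadow files (those names look like
+// Hadoop local-filesystem checksum files, though this diff does not say so).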
\ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala new file mode 100644 index 0000000000..f81ab2607f --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -0,0 +1,114 @@ +package spark.streaming + +import java.net.{SocketException, Socket, ServerSocket} +import java.io.{BufferedWriter, OutputStreamWriter} +import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.ManualClock +import spark.storage.StorageLevel +import spark.Logging + + +class InputStreamsSuite extends TestSuiteBase { + + test("network input stream") { + val serverPort = 9999 + val server = new TestServer(9999) + server.start() + val ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + + val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3) + val expectedOutput = input.map(_.toString) + for (i <- 0 until input.size) { + server.send(input(i).toString + "\n") + Thread.sleep(1000) + clock.addToTime(1000) + } + val startTime = System.currentTimeMillis() + while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) + Thread.sleep(100) + } + Thread.sleep(5000) + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping server") + server.stop() + logInfo("Stopping context") + ssc.stop() + + assert(outputBuffer.size === expectedOutput.size) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + assert(outputBuffer(i).head === expectedOutput(i)) + } + } +} + + +class TestServer(port: Int) extends Logging { + + val queue = new ArrayBlockingQueue[String](100) + + val serverSocket = new ServerSocket(port) + + val servingThread = new Thread() { + override def run() { + try { + while(true) { + logInfo("Accepting connections on port " + port) + val clientSocket = serverSocket.accept() + logInfo("New connection") + try { + clientSocket.setTcpNoDelay(true) + val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + + while(clientSocket.isConnected) { + val msg = queue.poll(100, TimeUnit.MILLISECONDS) + if (msg != null) { + outputStream.write(msg) + outputStream.flush() + logInfo("Message '" + msg + "' sent") + } + } + } catch { + case e: SocketException => println(e) + } finally { + logInfo("Connection closed") + if (!clientSocket.isClosed) clientSocket.close() + } + } + } catch { + case ie: InterruptedException => + + } finally { + serverSocket.close() + } + } + } + + def start() { servingThread.start() } + + def send(msg: String) { queue.add(msg) } + + def stop() { servingThread.interrupt() } +} + +object TestServer { + def main(args: Array[String]) { + val s = new TestServer(9999) + s.start() + while(true) { + Thread.sleep(1000) + s.send("hello") + } + } +} diff --git 
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala new file mode 100644 index 0000000000..c1b7772e7b --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -0,0 +1,216 @@
+package spark.streaming
+
+import spark.{RDD, Logging}
+import util.ManualClock
+import collection.mutable.ArrayBuffer
+import org.scalatest.FunSuite
+import collection.mutable.SynchronizedBuffer
+
+class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
+  extends InputDStream[T](ssc_) {
+  var currentIndex = 0
+
+  def start() {}
+
+  def stop() {}
+
+  def compute(validTime: Time): Option[RDD[T]] = {
+    logInfo("Computing RDD for time " + validTime)
+    val rdd = if (currentIndex < input.size) {
+      ssc.sc.makeRDD(input(currentIndex), numPartitions)
+    } else {
+      ssc.sc.makeRDD(Seq[T](), numPartitions)
+    }
+    logInfo("Created RDD " + rdd.id)
+    currentIndex += 1
+    Some(rdd)
+  }
+}
+
+class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+  extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+    val collected = rdd.collect()
+    output += collected
+  })
+
+trait TestSuiteBase extends FunSuite with Logging {
+
+  System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+  def framework() = "TestSuiteBase"
+
+  def master() = "local[2]"
+
+  def batchDuration() = Seconds(1)
+
+  def checkpointFile() = null.asInstanceOf[String]
+
+  def checkpointInterval() = batchDuration
+
+  def numInputPartitions() = 2
+
+  def maxWaitTimeMillis() = 10000
+
+  def setupStreams[U: ClassManifest, V: ClassManifest](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V]
+    ): StreamingContext = {
+
+    // Create StreamingContext
+    val ssc = new StreamingContext(master, framework)
+    ssc.setBatchDuration(batchDuration)
+    if (checkpointFile != null) {
+      ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+    }
+
+    // Set up the stream computation
+    val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+    val operatedStream = operation(inputStream)
+    val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+    ssc.registerInputStream(inputStream)
+    ssc.registerOutputStream(outputStream)
+    ssc
+  }
+
+  def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+      input1: Seq[Seq[U]],
+      input2: Seq[Seq[V]],
+      operation: (DStream[U], DStream[V]) => DStream[W]
+    ): StreamingContext = {
+
+    // Create StreamingContext
+    val ssc = new StreamingContext(master, framework)
+    ssc.setBatchDuration(batchDuration)
+    if (checkpointFile != null) {
+      ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+    }
+
+    // Set up the stream computation
+    val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
+    val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
+    val operatedStream = operation(inputStream1, inputStream2)
+    val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+    ssc.registerInputStream(inputStream1)
+    ssc.registerInputStream(inputStream2)
+    ssc.registerOutputStream(outputStream)
+    ssc
+  }
+
+
+  def runStreams[V: ClassManifest](
+      ssc: StreamingContext,
+      numBatches: Int,
+      numExpectedOutput: Int
+    ): Seq[Seq[V]] = {
+
+    assert(numBatches > 0, "Number of batches to run stream computation is zero")
+    assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " batches is zero")
after " + numBatches + " is zero") + logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) + + // Get the output buffer + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + val output = outputStream.output + + try { + // Start computation + ssc.start() + + // Advance manual clock + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + clock.addToTime(numBatches * batchDuration.milliseconds) + logInfo("Manual clock after advancing = " + clock.time) + + // Wait until expected number of output items have been generated + val startTime = System.currentTimeMillis() + while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) + Thread.sleep(100) + } + val timeTaken = System.currentTimeMillis() - startTime + + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") + + Thread.sleep(500) // Give some time for the forgetting old RDDs to complete + } catch { + case e: Exception => e.printStackTrace(); throw e; + } finally { + ssc.stop() + } + + output + } + + def verifyOutput[V: ClassManifest]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + useSet: Boolean + ) { + logInfo("--------------------------------") + logInfo("output.size = " + output.size) + logInfo("output") + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Match the output with the expected output + assert(output.size === expectedOutput.size, "Number of outputs do not match") + for (i <- 0 until output.size) { + if (useSet) { + assert(output(i).toSet === expectedOutput(i).toSet) + } else { + assert(output(i).toList === expectedOutput(i).toList) + } + } + logInfo("Output verified successfully") + } + + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + useSet: Boolean = false + ) { + testOperation[U, V](input, operation, expectedOutput, -1, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatches: Int, + useSet: Boolean + ) { + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, numBatches_, expectedOutput.size) + verifyOutput[V](output, expectedOutput, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + useSet: Boolean + ) { + testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + numBatches: Int, + useSet: Boolean + ) { + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V, 
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala new file mode 100644 index 0000000000..90d67844bb --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -0,0 +1,188 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+
+class WindowOperationsSuite extends TestSuiteBase {
+
+  override def framework() = "WindowOperationsSuite"
+
+  override def maxWaitTimeMillis() = 20000
+
+  val largerSlideInput = Seq(
+    Seq(("a", 1)),
+    Seq(("a", 2)),  // 1st window from here
+    Seq(("a", 3)),
+    Seq(("a", 4)),  // 2nd window from here
+    Seq(("a", 5)),
+    Seq(("a", 6)),  // 3rd window from here
+    Seq(),
+    Seq()           // 4th window from here
+  )
+
+  val largerSlideOutput = Seq(
+    Seq(("a", 3)),
+    Seq(("a", 10)),
+    Seq(("a", 18)),
+    Seq(("a", 11))
+  )
+
+
+  val bigInput = Seq(
+    Seq(("a", 1)),
+    Seq(("a", 1), ("b", 1)),
+    Seq(("a", 1), ("b", 1), ("c", 1)),
+    Seq(("a", 1), ("b", 1)),
+    Seq(("a", 1)),
+    Seq(),
+    Seq(("a", 1)),
+    Seq(("a", 1), ("b", 1)),
+    Seq(("a", 1), ("b", 1), ("c", 1)),
+    Seq(("a", 1), ("b", 1)),
+    Seq(("a", 1)),
+    Seq()
+  )
+
+  val bigOutput = Seq(
+    Seq(("a", 1)),
+    Seq(("a", 2), ("b", 1)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 1)),
+    Seq(("a", 1)),
+    Seq(("a", 1)),
+    Seq(("a", 2), ("b", 1)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 1)),
+    Seq(("a", 1))
+  )
+
+  /*
+  The output of reduceByKeyAndWindow with an inverse reduce function differs
+  from that of the naive reduceByKeyAndWindow. Even if the count of a
+  particular key is 0, the key does not get eliminated from the RDDs of
+  ReducedWindowedDStream. This causes the number of keys in these RDDs to
+  increase forever. A more generalized version that allows elimination of
+  keys should be considered.
+  */
+  val bigOutputInv = Seq(
+    Seq(("a", 1)),
+    Seq(("a", 2), ("b", 1)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 1), ("c", 0)),
+    Seq(("a", 1), ("b", 0), ("c", 0)),
+    Seq(("a", 1), ("b", 0), ("c", 0)),
+    Seq(("a", 2), ("b", 1), ("c", 0)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 2), ("c", 1)),
+    Seq(("a", 2), ("b", 1), ("c", 0)),
+    Seq(("a", 1), ("b", 0), ("c", 0))
+  )
+
+  def testReduceByKeyAndWindow(
+      name: String,
+      input: Seq[Seq[(String, Int)]],
+      expectedOutput: Seq[Seq[(String, Int)]],
+      windowTime: Time = batchDuration * 2,
+      slideTime: Time = batchDuration
+    ) {
+    test("reduceByKeyAndWindow - " + name) {
+      val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+      val operation = (s: DStream[(String, Int)]) => {
+        s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist()
+      }
+      testOperation(input, operation, expectedOutput, numBatches, true)
+    }
+  }
+
+  def testReduceByKeyAndWindowInv(
+      name: String,
+      input: Seq[Seq[(String, Int)]],
+      expectedOutput: Seq[Seq[(String, Int)]],
+      windowTime: Time = batchDuration * 2,
+      slideTime: Time = batchDuration
+    ) {
+    test("reduceByKeyAndWindowInv - " + name) {
+      val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+      val operation = (s: DStream[(String, Int)]) => {
+        s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist()
+      }
+      testOperation(input, operation, expectedOutput, numBatches, true)
+    }
+  }
+
+
+  // Testing naive reduceByKeyAndWindow (without invertible function)
+
+  testReduceByKeyAndWindow(
+    "basic reduction",
+    Seq( Seq(("a", 1), ("a", 3)) ),
+    Seq( Seq(("a", 4)) )
+  )
+
+  testReduceByKeyAndWindow(
+    "key already in window and new value added into window",
+    Seq( Seq(("a", 1)), Seq(("a", 1)) ),
+    Seq( Seq(("a", 1)), Seq(("a", 2)) )
+  )
+
+  testReduceByKeyAndWindow(
+    "new key added into window",
+    Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),
+    Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) )
+  )
+
+  testReduceByKeyAndWindow(
+    "key removed from window",
+    Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ),
+    Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() )
+  )
+
+  testReduceByKeyAndWindow(
+    "larger slide time",
+    largerSlideInput,
+    largerSlideOutput,
+    Seconds(4),
+    Seconds(2)
+  )
+
+  testReduceByKeyAndWindow("big test", bigInput, bigOutput)
+
+
+  // Testing reduceByKeyAndWindow (with invertible reduce function)
+
+  testReduceByKeyAndWindowInv(
+    "basic reduction",
+    Seq( Seq(("a", 1), ("a", 3)) ),
+    Seq( Seq(("a", 4)) )
+  )
+
+  testReduceByKeyAndWindowInv(
+    "key already in window and new value added into window",
+    Seq( Seq(("a", 1)), Seq(("a", 1)) ),
+    Seq( Seq(("a", 1)), Seq(("a", 2)) )
+  )
+
+  testReduceByKeyAndWindowInv(
+    "new key added into window",
+    Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),
+    Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) )
+  )
+
+  testReduceByKeyAndWindowInv(
+    "key removed from window",
+    Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ),
+    Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) )
+  )
+
+  testReduceByKeyAndWindowInv(
+    "larger slide time",
+    largerSlideInput,
+    largerSlideOutput,
+    Seconds(4),
+    Seconds(2)
+  )
+
+  testReduceByKeyAndWindowInv("big test", bigInput, bigOutputInv)
+}
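The zero-count behavior described in the comment above can be reproduced with a plain-Scala simulation of the invertible update rule (new window = old window + entering batches - leaving batches). This is an illustrative sketch, not Spark code; object and variable names are made up:

    object ZeroCountDemo {
      // Simulate reduceByKeyAndWindow(_ + _, _ - _) with a window of 2 batches
      // sliding by 1 batch, over per-batch (key, count) pairs.
      def main(args: Array[String]) {
        val batches = Seq(Seq(("a", 1), ("b", 1)), Seq(("a", 1)), Seq(), Seq())
        var window = Map[String, Int]()
        for (i <- batches.indices) {
          val entering = batches(i)
          val leaving = if (i >= 2) batches(i - 2) else Seq()
          for ((k, v) <- entering) window += (k -> (window.getOrElse(k, 0) + v))
          for ((k, v) <- leaving) window += (k -> (window(k) - v))
          // Once "b" leaves the window its count becomes 0, but the key is
          // never dropped, mirroring the ("b", 0) entries in bigOutputInv.
          println(window)
        }
      }
    }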