path: root/streaming
author    Tathagata Das <tathagata.das1565@gmail.com>  2012-10-21 17:40:08 -0700
committer Tathagata Das <tathagata.das1565@gmail.com>  2012-10-21 17:40:08 -0700
commit    d85c66636ba3b5d32f7e3b47c5b68e1064f8f588 (patch)
tree      89d5c9b6b5cc84f9e5a8aed3e1f22e461ffa5ebf /streaming
parent    c4a2b6f636040bacd3d4b443e65cc33dafd0aa7e (diff)
Added MapValuesDStream, FlatMapValuesDStream and CoGroupedDStream, and therefore the DStream operations mapValues, flatMapValues, cogroup, and join. Also added tests for the DStream operations filter, glom, mapPartitions, groupByKey, mapValues, flatMapValues, cogroup, and join.
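For orientation, a minimal usage sketch of the new pair operations (hypothetical stream and helper names; assumes the implicit DStream-to-PairDStreamFunctions conversion from spark.streaming.StreamingContext._ is in scope, as it is in the code and tests below):

import spark.streaming.DStream
import spark.streaming.StreamingContext._  // assumed home of the implicit pair-DStream conversion

// Hypothetical helper: `counts` is a per-batch (word, count) stream built elsewhere.
def adjustCounts(counts: DStream[(String, Int)]): DStream[(String, Int)] = {
  counts
    .mapValues(_ + 1)                    // new in this commit: transform values, keys untouched
    .flatMapValues(v => Seq(v, v * 10))  // new in this commit: one value may expand to several
    .reduceByKey(_ + _)                  // pre-existing operation, unchanged here
}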
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala           37
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala                    63
-rw-r--r--  streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala       75
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/CountRaw.scala           2
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala            2
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala   2
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala       2
-rw-r--r--  streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala          86
-rw-r--r--  streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala           59
9 files changed, 293 insertions, 35 deletions
diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala
new file mode 100644
index 0000000000..5522e2ee21
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala
@@ -0,0 +1,37 @@
+package spark.streaming
+
+import spark.{CoGroupedRDD, RDD, Partitioner}
+
+class CoGroupedDStream[K : ClassManifest](
+ parents: Seq[DStream[(_, _)]],
+ partitioner: Partitioner
+ ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) {
+
+ if (parents.length == 0) {
+ throw new IllegalArgumentException("Empty array of parents")
+ }
+
+ if (parents.map(_.ssc).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different StreamingContexts")
+ }
+
+ if (parents.map(_.slideTime).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different slide times")
+ }
+
+ override def dependencies = parents.toList
+
+ override def slideTime = parents.head.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = {
+ val part = partitioner
+ val rdds = parents.flatMap(_.getOrCompute(validTime))
+ if (rdds.size > 0) {
+ val q = new CoGroupedRDD[K](rdds, part)
+ Some(q)
+ } else {
+ None
+ }
+ }
+
+}
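CoGroupedDStream is an internal building block: its compute method cogroups whichever parent RDDs exist for the batch, and user code reaches it through the cogroup operation added to PairDStreamFunctions further down in this patch. A hedged sketch (hypothetical streams, same implicit conversion import as in the sketch above):

// Only cogroup itself comes from this patch; the streams are made up.
def groupBoth(counts: DStream[(String, Int)],
              tags:   DStream[(String, String)]): DStream[(String, (Seq[Int], Seq[String]))] = {
  counts.cogroup(tags)  // internally wraps both parents in a CoGroupedDStream with the default HashPartitioner
}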
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index f6cd135e59..38bb7c8b94 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -244,27 +244,27 @@ extends Serializable with Logging {
* DStream operations
* --------------
*/
- def map[U: ClassManifest](mapFunc: T => U) = {
+ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = {
new MappedDStream(this, ssc.sc.clean(mapFunc))
}
- def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = {
+ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = {
new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc))
}
- def filter(filterFunc: T => Boolean) = new FilteredDStream(this, filterFunc)
+ def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc)
- def glom() = new GlommedDStream(this)
+ def glom(): DStream[Array[T]] = new GlommedDStream(this)
- def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = {
+ def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = {
new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc))
}
- def reduce(reduceFunc: (T, T) => T) = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2)
+ def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2)
- def count() = this.map(_ => 1).reduce(_ + _)
+ def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _)
- def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2)
+ def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2)
def foreach(foreachFunc: T => Unit) {
val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc))
@@ -341,7 +341,7 @@ extends Serializable with Logging {
this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime)
}
- def union(that: DStream[T]): DStream[T] = new UnifiedDStream[T](Array(this, that))
+ def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that))
def slice(interval: Interval): Seq[RDD[T]] = {
slice(interval.beginTime, interval.endTime)
@@ -507,8 +507,47 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
* TODO
*/
-class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]])
- extends DStream[T](parents(0).ssc) {
+class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+ parent: DStream[(K, V)],
+ mapValueFunc: V => U
+ ) extends DStream[(K, U)](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+ parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc))
+ }
+}
+
+
+/**
+ * TODO
+ */
+
+class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+ parent: DStream[(K, V)],
+ flatMapValueFunc: V => TraversableOnce[U]
+ ) extends DStream[(K, U)](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+ parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc))
+ }
+}
+
+
+
+/**
+ * TODO
+ */
+
+class UnionDStream[T: ClassManifest](parents: Array[DStream[T]])
+ extends DStream[T](parents.head.ssc) {
if (parents.length == 0) {
throw new IllegalArgumentException("Empty array of parents")
@@ -524,7 +563,7 @@ class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]])
override def dependencies = parents.toList
- override def slideTime: Time = parents(0).slideTime
+ override def slideTime: Time = parents.head.slideTime
override def compute(validTime: Time): Option[RDD[T]] = {
val rdds = new ArrayBuffer[RDD[T]]()
diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
index 0bd0321928..5de57eb2fd 100644
--- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
@@ -1,17 +1,16 @@
package spark.streaming
import scala.collection.mutable.ArrayBuffer
-import spark.Partitioner
-import spark.HashPartitioner
+import spark.{Manifests, RDD, Partitioner, HashPartitioner}
import spark.streaming.StreamingContext._
import javax.annotation.Nullable
-class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)])
+class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)])
extends Serializable {
- def ssc = stream.ssc
+ def ssc = self.ssc
- def defaultPartitioner(numPartitions: Int = stream.ssc.sc.defaultParallelism) = {
+ def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = {
new HashPartitioner(numPartitions)
}
@@ -28,10 +27,10 @@ extends Serializable {
}
def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = {
- def createCombiner(v: V) = ArrayBuffer[V](v)
- def mergeValue(c: ArrayBuffer[V], v: V) = (c += v)
- def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2)
- combineByKey(createCombiner _, mergeValue _, mergeCombiner _, partitioner).asInstanceOf[DStream[(K, Seq[V])]]
+ val createCombiner = (v: V) => ArrayBuffer[V](v)
+ val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v)
+ val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2)
+ combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner).asInstanceOf[DStream[(K, Seq[V])]]
}
def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = {
@@ -52,7 +51,7 @@ extends Serializable {
mergeValue: (C, V) => C,
mergeCombiner: (C, C) => C,
partitioner: Partitioner) : ShuffledDStream[K, V, C] = {
- new ShuffledDStream[K, V, C](stream, createCombiner, mergeValue, mergeCombiner, partitioner)
+ new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner)
}
def groupByKeyAndWindow(windowTime: Time, slideTime: Time): DStream[(K, Seq[V])] = {
@@ -72,14 +71,14 @@ extends Serializable {
slideTime: Time,
partitioner: Partitioner
): DStream[(K, Seq[V])] = {
- stream.window(windowTime, slideTime).groupByKey(partitioner)
+ self.window(windowTime, slideTime).groupByKey(partitioner)
}
def reduceByKeyAndWindow(
reduceFunc: (V, V) => V,
windowTime: Time
): DStream[(K, V)] = {
- reduceByKeyAndWindow(reduceFunc, windowTime, stream.slideTime, defaultPartitioner())
+ reduceByKeyAndWindow(reduceFunc, windowTime, self.slideTime, defaultPartitioner())
}
def reduceByKeyAndWindow(
@@ -105,7 +104,7 @@ extends Serializable {
slideTime: Time,
partitioner: Partitioner
): DStream[(K, V)] = {
- stream.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner)
+ self.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner)
}
// This method is the efficient sliding window reduce operation,
@@ -148,7 +147,7 @@ extends Serializable {
val cleanedReduceFunc = ssc.sc.clean(reduceFunc)
val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc)
new ReducedWindowedDStream[K, V](
- stream, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner)
+ self, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner)
}
// TODO:
@@ -184,7 +183,53 @@ extends Serializable {
partitioner: Partitioner,
rememberPartitioner: Boolean
): DStream[(K, S)] = {
- new StateDStream(stream, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+ new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+ }
+
+
+ def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = {
+ new MapValuesDStream[K, V, U](self, mapValuesFunc)
+ }
+
+ def flatMapValues[U: ClassManifest](
+ flatMapValuesFunc: V => TraversableOnce[U]
+ ): DStream[(K, U)] = {
+ new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc)
+ }
+
+ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, defaultPartitioner())
+ }
+
+ def cogroup[W: ClassManifest](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (Seq[V], Seq[W]))] = {
+
+ val cgd = new CoGroupedDStream[K](
+ Seq(self.asInstanceOf[DStream[(_, _)]], other.asInstanceOf[DStream[(_, _)]]),
+ partitioner
+ )
+ val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)(
+ classManifest[K],
+ Manifests.seqSeqManifest
+ )
+ pdfs.mapValues {
+ case Seq(vs, ws) =>
+ (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
+ }
+ }
+
+ def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = {
+ join[W](other, defaultPartitioner())
+ }
+
+ def join[W: ClassManifest](other: DStream[(K, W)], partitioner: Partitioner): DStream[(K, (V, W))] = {
+ this.cogroup(other, partitioner)
+ .flatMapValues{
+ case (vs, ws) =>
+ for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
+ }
}
}
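The new join is cogroup followed by flatMapValues over the per-key value sequences (the for-comprehension above). A plain-Scala sketch of the per-key semantics on made-up data:

// After cogroup, each key carries a Seq of left values and a Seq of right values;
// join pairs every left value with every right value, so a key with an empty side drops out.
val cogrouped = Seq(
  ("a", (Seq(1, 2), Seq("x"))),      // 2 x 1 -> two joined pairs
  ("b", (Seq(3),    Seq[String]()))  // empty right side -> no output for "b"
)
val joined = cogrouped.flatMap { case (k, (vs, ws)) =>
  for (v <- vs.iterator; w <- ws.iterator) yield (k, (v, w))
}
// joined == Seq(("a", (1, "x")), ("a", (2, "x")))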
diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala
index c78c1e9660..ed571d22e3 100644
--- a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala
@@ -25,7 +25,7 @@ object CountRaw {
val rawStreams = (1 to numStreams).map(_ =>
ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
- val union = new UnifiedDStream(rawStreams)
+ val union = new UnionDStream(rawStreams)
union.map(_.length + 2).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString))
ssc.start()
}
diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala
index cc52da7bd4..6af1c36891 100644
--- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala
@@ -25,7 +25,7 @@ object GrepRaw {
val rawStreams = (1 to numStreams).map(_ =>
ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
- val union = new UnifiedDStream(rawStreams)
+ val union = new UnionDStream(rawStreams)
union.filter(_.contains("Culpepper")).count().foreachRDD(r =>
println("Grep count: " + r.collect().mkString))
ssc.start()
diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
index 3ba07d0448..af0a3bf98a 100644
--- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
@@ -37,7 +37,7 @@ object TopKWordCountRaw {
val rawStreams = (1 to streams).map(_ =>
ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
- val union = new UnifiedDStream(rawStreams)
+ val union = new UnionDStream(rawStreams)
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
index 9702003805..98bafec529 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
@@ -37,7 +37,7 @@ object WordCountRaw {
val rawStreams = (1 to streams).map(_ =>
ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray
- val union = new UnifiedDStream(rawStreams)
+ val union = new UnionDStream(rawStreams)
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
index 28bbb152ca..db95c2cfaa 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
@@ -24,6 +24,44 @@ class DStreamBasicSuite extends DStreamSuiteBase {
)
}
+ test("filter") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.filter(x => (x % 2 == 0)),
+ input.map(_.filter(x => (x % 2 == 0)))
+ )
+ }
+
+ test("glom") {
+ assert(numInputPartitions === 2, "Number of input partitions has been changed from 2")
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ val output = Seq(
+ Seq( Seq(1, 2), Seq(3, 4) ),
+ Seq( Seq(5, 6), Seq(7, 8) ),
+ Seq( Seq(9, 10), Seq(11, 12) )
+ )
+ val operation = (r: DStream[Int]) => r.glom().map(_.toSeq)
+ testOperation(input, operation, output)
+ }
+
+ test("mapPartitions") {
+ assert(numInputPartitions === 2, "Number of input partitions has been changed from 2")
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23))
+ val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _)))
+ testOperation(input, operation, output, true)
+ }
+
+ test("groupByKey") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(),
+ Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ),
+ true
+ )
+ }
+
test("reduceByKey") {
testOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
@@ -41,6 +79,54 @@ class DStreamBasicSuite extends DStreamSuiteBase {
)
}
+ test("mapValues") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10),
+ Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ),
+ true
+ )
+ }
+
+ test("flatMapValues") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)),
+ Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ),
+ true
+ )
+ }
+
+ test("cogroup") {
+ val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() )
+ val outputData = Seq(
+ Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ),
+ Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ),
+ Seq( ("", (Seq(1), Seq())) ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("join") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (1, "x")), ("b", (1, "x")) ),
+ Seq( ("", (1, "x")) ),
+ Seq( ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x,1)).join(s2.map(x => (x,"x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
test("updateStateByKey") {
val inputData =
Seq(
diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
index 59fa5a6f22..2a4b37c965 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
@@ -6,7 +6,7 @@ import collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
import collection.mutable.SynchronizedBuffer
-class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, val input: Seq[Seq[T]])
+class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
extends InputDStream[T](ssc_) {
var currentIndex = 0
@@ -17,9 +17,9 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, val input: Seq[
def compute(validTime: Time): Option[RDD[T]] = {
logInfo("Computing RDD for time " + validTime)
val rdd = if (currentIndex < input.size) {
- ssc.sc.makeRDD(input(currentIndex), 2)
+ ssc.sc.makeRDD(input(currentIndex), numPartitions)
} else {
- ssc.sc.makeRDD(Seq[T](), 2)
+ ssc.sc.makeRDD(Seq[T](), numPartitions)
}
logInfo("Created RDD " + rdd.id)
currentIndex += 1
@@ -47,6 +47,8 @@ trait DStreamSuiteBase extends FunSuite with Logging {
def checkpointInterval() = batchDuration
+ def numInputPartitions() = 2
+
def maxWaitTimeMillis() = 10000
def setupStreams[U: ClassManifest, V: ClassManifest](
@@ -62,7 +64,7 @@ trait DStreamSuiteBase extends FunSuite with Logging {
}
// Setup the stream computation
- val inputStream = new TestInputStream(ssc, input)
+ val inputStream = new TestInputStream(ssc, input, numInputPartitions)
val operatedStream = operation(inputStream)
val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
ssc.registerInputStream(inputStream)
@@ -70,6 +72,31 @@ trait DStreamSuiteBase extends FunSuite with Logging {
ssc
}
+ def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W]
+ ): StreamingContext = {
+
+ // Create StreamingContext
+ val ssc = new StreamingContext(master, framework)
+ ssc.setBatchDuration(batchDuration)
+ if (checkpointFile != null) {
+ ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ }
+
+ // Setup the stream computation
+ val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
+ val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
+ val operatedStream = operation(inputStream1, inputStream2)
+ val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+ ssc.registerInputStream(inputStream1)
+ ssc.registerInputStream(inputStream2)
+ ssc.registerOutputStream(outputStream)
+ ssc
+ }
+
+
def runStreams[V: ClassManifest](
ssc: StreamingContext,
numBatches: Int,
@@ -162,4 +189,28 @@ trait DStreamSuiteBase extends FunSuite with Logging {
val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
verifyOutput[V](output, expectedOutput, useSet)
}
+
+ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W],
+ expectedOutput: Seq[Seq[W]],
+ useSet: Boolean
+ ) {
+ testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet)
+ }
+
+ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W],
+ expectedOutput: Seq[Seq[W]],
+ numBatches: Int,
+ useSet: Boolean
+ ) {
+ val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
+ val ssc = setupStreams[U, V, W](input1, input2, operation)
+ val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[W](output, expectedOutput, useSet)
+ }
}