-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala | 13
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java | 21
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 20
-rw-r--r--  docs/streaming-programming-guide.md | 4
-rw-r--r--  project/SparkBuild.scala | 1
-rw-r--r--  streaming/pom.xml | 4
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStream.scala | 7
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala | 6
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala | 6
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java | 33
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala | 36
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala | 38
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala | 4
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 61
17 files changed, 272 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index f9b6ee351a..043cb183ba 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -94,6 +94,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
+
+ /**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index b3eb739f4e..2142fd7327 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -108,6 +108,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 662990049b..3b359a8fd6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -82,6 +82,17 @@ JavaRDDLike[T, JavaRDD[T]] {
rdd.coalesce(numPartitions, shuffle)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0355618e43..6e88be6f6a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -266,6 +266,19 @@ abstract class RDD[T: ClassManifest](
def distinct(): RDD[T] = distinct(partitions.size)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): RDD[T] = {
+ coalesce(numPartitions, true)
+ }
+
+ /**
* Return a new RDD that is reduced into `numPartitions` partitions.
*
* This results in a narrow dependency, e.g. if you go from 1000 partitions
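For reference, a minimal usage sketch of the API this hunk adds (not part of the patch; assumes a live SparkContext bound to `sc`, as in spark-shell). The Java wrappers above delegate to this core method, so the same semantics apply there:

    // Start from 4 partitions.
    val rdd = sc.parallelize(1 to 1000, 4)
    // repartition always shuffles, so it can both grow and shrink the partition count.
    val wider = rdd.repartition(16)    // 4 -> 16 partitions
    val narrower = rdd.repartition(2)  // 4 -> 2 partitions, still shuffles
    // When only shrinking, coalesce merges partitions locally without a shuffle.
    val merged = rdd.coalesce(2)       // 4 -> 2 partitions, narrow dependency
    assert(wider.partitions.size == 16 && merged.partitions.size == 2)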
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 7b0bb89ab2..352036f182 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -473,6 +473,27 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void repartition() {
+ // Growing number of partitions
+ JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
+ JavaRDD<Integer> repartitioned1 = in1.repartition(4);
+ List<List<Integer>> result1 = repartitioned1.glom().collect();
+ Assert.assertEquals(4, result1.size());
+ for (List<Integer> l: result1) {
+ Assert.assertTrue(l.size() > 0);
+ }
+
+ // Shrinking number of partitions
+ JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
+ JavaRDD<Integer> repartitioned2 = in2.repartition(2);
+ List<List<Integer>> result2 = repartitioned2.glom().collect();
+ Assert.assertEquals(2, result2.size());
+ for (List<Integer> l: result2) {
+ Assert.assertTrue(l.size() > 0);
+ }
+ }
+
+ @Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6d1bc5e296..354ab8ae5d 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -139,6 +139,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(rdd.union(emptyKv).collect().size === 2)
}
+ test("repartitioned RDDs") {
+ val data = sc.parallelize(1 to 1000, 10)
+
+ // Coalesce partitions
+ val repartitioned1 = data.repartition(2)
+ assert(repartitioned1.partitions.size == 2)
+ val partitions1 = repartitioned1.glom().collect()
+ assert(partitions1(0).length > 0)
+ assert(partitions1(1).length > 0)
+ assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+ // Split partitions
+ val repartitioned2 = data.repartition(20)
+ assert(repartitioned2.partitions.size == 20)
+ val partitions2 = repartitioned2.glom().collect()
+ assert(partitions2(0).length > 0)
+ assert(partitions2(19).length > 0)
+ assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+ }
+
test("coalesced RDDs") {
val data = sc.parallelize(1 to 10, 10)
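The test above relies on `glom()` to make partition boundaries observable; a small sketch of that behavior (assumes a SparkContext `sc`; not part of this patch):

    val data = sc.parallelize(1 to 8, 2)
    // glom() wraps each partition's contents in a single Array, so
    // collect() yields one array per partition instead of a flat list.
    val perPartition: Array[Array[Int]] = data.glom().collect()
    assert(perPartition.length == 2)                      // one entry per partition
    assert(perPartition.flatten.toSet == (1 to 8).toSet)  // no elements lost or duplicated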
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 835b257238..851e30fe76 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -73,6 +73,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
Iterator[T] => Iterator[U] when running on an DStream of type T. </td>
</tr>
<tr>
+ <td> <b>repartition</b>(<i>numPartitions</i>) </td>
+ <td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
+</tr>
+<tr>
<td> <b>union</b>(<i>otherStream</i>) </td>
<td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
</tr>
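A hedged usage sketch for the new table entry (illustrative only; assumes a StreamingContext `ssc` and a socket source on localhost:9999):

    val lines = ssc.socketTextStream("localhost", 9999)
    // A receiver-based stream starts with few partitions; repartition
    // redistributes each batch across more tasks before heavier processing.
    val rebalanced = lines.repartition(8)
    rebalanced.count().print()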
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 8d7cbae821..45fd30a7c8 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -292,6 +292,7 @@ object SparkBuild extends Build {
"org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
exclude("com.sun.jdmk", "jmxtools")
exclude("com.sun.jmx", "jmxri")
+ exclude("net.sf.jopt-simple", "jopt-simple")
)
)
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 8022c4fe18..7a9ae6a97b 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -73,6 +73,10 @@
<groupId>com.sun.jdmk</groupId>
<artifactId>jmxtools</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>net.sf.jopt-simple</groupId>
+ <artifactId>jopt-simple</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 38e34795b4..9ceff754c4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -438,6 +438,13 @@ abstract class DStream[T: ClassManifest] (
*/
def glom(): DStream[Array[T]] = new GlommedDStream(this)
+
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
+
/**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
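Because the new method is defined via `transform`, it is exactly equivalent to repartitioning each batch's RDD by hand; a brief sketch of that equivalence (assumes an existing `DStream[String]` named `stream`):

    // Both lines yield a DStream whose every RDD has exactly 8 partitions.
    val a: DStream[String] = stream.repartition(8)
    val b: DStream[String] = stream.transform(rdd => rdd.repartition(8))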
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
index d1932b6b05..1a2aeaa879 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
@@ -94,6 +94,12 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM
*/
def union(that: JavaDStream[T]): JavaDStream[T] =
dstream.union(that.dstream)
+
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): JavaDStream[T] = dstream.repartition(numPartitions)
}
object JavaDStream {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 4dd6b7d096..c6cd635afa 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -59,6 +59,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/** Persist the RDDs of this DStream with the given storage level */
def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel)
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions)
+
/** Method that generates a RDD for the given Duration */
def compute(validTime: Time): JavaPairRDD[K, V] = {
dstream.compute(validTime) match {
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 5d48908667..ad4a8b9535 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -185,6 +185,39 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void testRepartitionMorePartitions() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 2);
+ JavaDStream repartitioned = stream.repartition(4);
+ JavaTestUtils.attachTestOutputStream(repartitioned);
+ List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+ Assert.assertEquals(2, result.size());
+ for (List<List<Integer>> rdd : result) {
+ Assert.assertEquals(4, rdd.size());
+ Assert.assertEquals(
+ 10, rdd.get(0).size() + rdd.get(1).size() + rdd.get(2).size() + rdd.get(3).size());
+ }
+ }
+
+ @Test
+ public void testRepartitionFewerPartitions() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 4);
+ JavaDStream repartitioned = stream.repartition(2);
+ JavaTestUtils.attachTestOutputStream(repartitioned);
+ List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+ Assert.assertEquals(2, result.size());
+ for (List<List<Integer>> rdd : result) {
+ Assert.assertEquals(2, rdd.size());
+ Assert.assertEquals(10, rdd.get(0).size() + rdd.get(1).size());
+ }
+ }
+
+ @Test
public void testGlom() {
List<List<String>> inputData = Arrays.asList(
Arrays.asList("giants", "dodgers"),
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
index 8a6604904d..5e384eeee4 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
@@ -33,9 +33,9 @@ trait JavaTestBase extends TestSuiteBase {
* The stream will be derived from the supplied lists of Java objects.
**/
def attachTestInputStream[T](
- ssc: JavaStreamingContext,
- data: JList[JList[T]],
- numPartitions: Int) = {
+ ssc: JavaStreamingContext,
+ data: JList[JList[T]],
+ numPartitions: Int) = {
val seqData = data.map(Seq(_:_*))
implicit val cm: ClassManifest[T] =
@@ -50,12 +50,11 @@ trait JavaTestBase extends TestSuiteBase {
* [[org.apache.spark.streaming.TestOutputStream]].
**/
def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]](
- dstream: JavaDStreamLike[T, This, R]) =
+ dstream: JavaDStreamLike[T, This, R]) =
{
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- val ostream = new TestOutputStream(dstream.dstream,
- new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]])
+ val ostream = new TestOutputStreamWithPartitions(dstream.dstream)
dstream.dstream.ssc.registerOutputStream(ostream)
}
@@ -63,9 +62,11 @@ trait JavaTestBase extends TestSuiteBase {
* Process all registered streams for a numBatches batches, failing if
* numExpectedOutput RDD's are not generated. Generated RDD's are collected
* and returned, represented as a list for each batch interval.
+ *
+ * Returns a list of items for each RDD.
*/
def runStreams[V](
- ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
+ ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
implicit val cm: ClassManifest[V] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
@@ -73,6 +74,27 @@ trait JavaTestBase extends TestSuiteBase {
res.map(entry => out.append(new ArrayList[V](entry)))
out
}
+
+ /**
+ * Process all registered streams for `numBatches` batches, failing if
+ * `numExpectedOutput` RDDs are not generated. Generated RDDs are collected
+ * and returned, represented as a list for each batch interval.
+ *
+ * Returns a sequence of RDDs. Each RDD is represented as several sequences of items,
+ * each representing one partition.
+ */
+ def runStreamsWithPartitions[V](ssc: JavaStreamingContext, numBatches: Int,
+ numExpectedOutput: Int): JList[JList[JList[V]]] = {
+ implicit val cm: ClassManifest[V] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput)
+ val out = new ArrayList[JList[JList[V]]]()
+ res.map { entry =>
+ val lists = entry.map(new ArrayList[V](_))
+ out.append(new ArrayList[JList[V]](lists))
+ }
+ out
+ }
}
object JavaTestUtils extends JavaTestBase {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index a2ac510a98..259ef1608c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -85,6 +85,44 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(input, operation, output, true)
}
+ test("repartition (more partitions)") {
+ val input = Seq(1 to 100, 101 to 200, 201 to 300)
+ val operation = (r: DStream[Int]) => r.repartition(5)
+ val ssc = setupStreams(input, operation, 2)
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 5)
+ assert(second.size === 5)
+ assert(third.size === 5)
+
+ assert(first.flatten.toSet === (1 to 100).toSet)
+ assert(second.flatten.toSet === (101 to 200).toSet)
+ assert(third.flatten.toSet === (201 to 300).toSet)
+ }
+
+ test("repartition (fewer partitions)") {
+ val input = Seq(1 to 100, 101 to 200, 201 to 300)
+ val operation = (r: DStream[Int]) => r.repartition(2)
+ val ssc = setupStreams(input, operation, 5)
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 2)
+ assert(second.size === 2)
+ assert(third.size === 2)
+
+ assert(first.flatten.toSet === (1 to 100).toSet)
+ assert(second.flatten.toSet === (101 to 200).toSet)
+ assert(third.flatten.toSet === (201 to 300).toSet)
+ }
+
test("groupByKey") {
testOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index a327de80b3..beb20831bd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -366,7 +366,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
logInfo("Manual clock after advancing = " + clock.time)
Thread.sleep(batchDuration.milliseconds)
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
- outputStream.output
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+ outputStream.output.map(_.flatten)
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 37dd9c4cc6..be140699c2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -60,8 +60,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
/**
* This is a output stream just for the testsuites. All the output is collected into a
* ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDDs, each containing a sequence of items.
*/
-class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+class TestOutputStream[T: ClassManifest](parent: DStream[T],
+ val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]())
extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
val collected = rdd.collect()
output += collected
@@ -76,6 +79,30 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu
}
/**
+ * This is an output stream just for the test suites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDDs, each containing a sequence of partitions, each
+ * containing a sequence of items.
+ */
+class TestOutputStreamWithPartitions[T: ClassManifest](parent: DStream[T],
+ val output: ArrayBuffer[Seq[Seq[T]]] = ArrayBuffer[Seq[Seq[T]]]())
+ extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ val collected = rdd.glom().collect().map(_.toSeq)
+ output += collected
+ }) {
+
+ // This is to clear the output buffer every time it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+
+ def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten))
+}
+
+/**
* This is the base trait for Spark Streaming testsuites. This provides basic functionality
* to run user-defined set of input on user-defined stream operations, and verify the output.
*/
@@ -108,7 +135,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
*/
def setupStreams[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V]
+ operation: DStream[U] => DStream[V],
+ numPartitions: Int = numInputPartitions
): StreamingContext = {
// Create StreamingContext
@@ -118,9 +146,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
}
// Setup the stream computation
- val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+ val inputStream = new TestInputStream(ssc, input, numPartitions)
val operatedStream = operation(inputStream)
- val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
ssc.registerInputStream(inputStream)
ssc.registerOutputStream(outputStream)
ssc
@@ -146,7 +175,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
val operatedStream = operation(inputStream1, inputStream2)
- val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]])
ssc.registerInputStream(inputStream1)
ssc.registerInputStream(inputStream2)
ssc.registerOutputStream(outputStream)
@@ -157,18 +187,37 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
* returns the collected output. It will wait until `numExpectedOutput` number of
* output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+ *
+ * Returns a sequence of items for each RDD.
*/
def runStreams[V: ClassManifest](
ssc: StreamingContext,
numBatches: Int,
numExpectedOutput: Int
): Seq[Seq[V]] = {
+ // Flatten each RDD into a single Seq
+ runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq)
+ }
+
+ /**
+ * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
+ * returns the collected output. It will wait until `numExpectedOutput` number of
+ * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+ *
+ * Returns a sequence of RDDs. Each RDD is represented as several sequences of items, each
+ * representing one partition.
+ */
+ def runStreamsWithPartitions[V: ClassManifest](
+ ssc: StreamingContext,
+ numBatches: Int,
+ numExpectedOutput: Int
+ ): Seq[Seq[Seq[V]]] = {
assert(numBatches > 0, "Number of batches to run stream computation is zero")
assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
// Get the output buffer
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
val output = outputStream.output
try {