Diffstat (limited to 'streaming/src/test')
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala  |   4
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala | 108
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala         |   2
3 files changed, 98 insertions(+), 16 deletions(-)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index bcb0c28bf0..bb73dbf29b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -324,7 +324,7 @@ class BasicOperationsSuite extends TestSuiteBase {
val updateStateOperation = (s: DStream[String]) => {
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
- Some(values.foldLeft(0)(_ + _) + state.getOrElse(0))
+ Some(values.sum + state.getOrElse(0))
}
s.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
}
@@ -359,7 +359,7 @@ class BasicOperationsSuite extends TestSuiteBase {
// updateFunc clears the state when a key's StateObject is seen without new values twice in a row
val updateFunc = (values: Seq[Int], state: Option[StateObject]) => {
val stateObj = state.getOrElse(new StateObject)
- values.foldLeft(0)(_ + _) match {
+ values.sum match {
case 0 => stateObj.expireCounter += 1 // no new values
case n => { // has new values, increment and reset expireCounter
stateObj.counter += n
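
The foldLeft-to-sum changes above are pure refactors: for a Seq[Int], values.sum
computes exactly values.foldLeft(0)(_ + _). A minimal standalone sketch of the
updateStateByKey pattern these two tests exercise (the object name and input
setup are illustrative, not part of the patch):

    import org.apache.spark.streaming.StreamingContext._
    import org.apache.spark.streaming.dstream.DStream

    object RunningCountSketch {
      // For each key, add the batch's new counts onto the previous state.
      def runningCount(words: DStream[String]): DStream[(String, Int)] = {
        val updateFunc = (values: Seq[Int], state: Option[Int]) =>
          Some(values.sum + state.getOrElse(0))
        // Note: updateStateByKey also requires ssc.checkpoint(dir) at runtime.
        words.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
      }
    }
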
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 717da8e004..9cc27ef7f0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,19 +17,22 @@
package org.apache.spark.streaming
-import org.scalatest.{FunSuite, BeforeAndAfter}
-import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.dstream.{DStream, NetworkReceiver}
+import org.apache.spark.util.{MetadataCleaner, Utils}
+import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Timeouts
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkException, SparkConf, SparkContext}
-import org.apache.spark.util.{Utils, MetadataCleaner}
-import org.apache.spark.streaming.dstream.DStream
-class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
+class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
val master = "local[2]"
val appName = this.getClass.getSimpleName
- val batchDuration = Seconds(1)
+ val batchDuration = Milliseconds(500)
val sparkHome = "someDir"
val envPair = "key" -> "value"
val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100
@@ -108,19 +111,31 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
myConf.set("spark.cleaner.ttl", ttl.toString)
val ssc1 = new StreamingContext(myConf, batchDuration)
+ addInputStream(ssc1).register
+ ssc1.start()
val cp = new Checkpoint(ssc1, Time(1000))
assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl)
ssc1.stop()
val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl)
- ssc = new StreamingContext(null, cp, null)
+ ssc = new StreamingContext(null, newCp, null)
assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl)
}
- test("start multiple times") {
+ test("start and stop state check") {
ssc = new StreamingContext(master, appName, batchDuration)
addInputStream(ssc).register
+ assert(ssc.state === ssc.StreamingContextState.Initialized)
+ ssc.start()
+ assert(ssc.state === ssc.StreamingContextState.Started)
+ ssc.stop()
+ assert(ssc.state === ssc.StreamingContextState.Stopped)
+ }
+
+ test("start multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
ssc.start()
intercept[SparkException] {
ssc.start()
@@ -133,18 +148,61 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
ssc.start()
ssc.stop()
ssc.stop()
- ssc = null
}
+ test("stop before start and start after stop") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+ ssc.stop() // stop before start should not throw exception
+ ssc.start()
+ ssc.stop()
+ intercept[SparkException] {
+ ssc.start() // start after stop should throw exception
+ }
+ }
+
+
test("stop only streaming context") {
ssc = new StreamingContext(master, appName, batchDuration)
sc = ssc.sparkContext
addInputStream(ssc).register
ssc.start()
ssc.stop(false)
- ssc = null
assert(sc.makeRDD(1 to 100).collect().size === 100)
ssc = new StreamingContext(sc, batchDuration)
+ addInputStream(ssc).register
+ ssc.start()
+ ssc.stop()
+ }
+
+ test("stop gracefully") {
+ val conf = new SparkConf().setMaster(master).setAppName(appName)
+ conf.set("spark.cleaner.ttl", "3600")
+ sc = new SparkContext(conf)
+ for (i <- 1 to 4) {
+ logInfo("==================================")
+ ssc = new StreamingContext(sc, batchDuration)
+ var runningCount = 0
+ TestReceiver.counter.set(1)
+ val input = ssc.networkStream(new TestReceiver)
+ input.count.foreachRDD(rdd => {
+ val count = rdd.first()
+ logInfo("Count = " + count)
+ runningCount += count.toInt
+ })
+ ssc.start()
+ ssc.awaitTermination(500)
+ ssc.stop(stopSparkContext = false, stopGracefully = true)
+ logInfo("Running count = " + runningCount)
+ logInfo("TestReceiver.counter = " + TestReceiver.counter.get())
+ assert(runningCount > 0)
+ assert(
+ (TestReceiver.counter.get() == runningCount + 1) ||
+ (TestReceiver.counter.get() == runningCount + 2),
+ "Received records = " + TestReceiver.counter.get() + ", " +
+ "processed records = " + runningCount
+ )
+ }
}
test("awaitTermination") {
@@ -199,7 +257,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
test("awaitTermination with error in job generation") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
-
inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register
val exception = intercept[TestException] {
ssc.start()
@@ -215,4 +272,29 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
}
}
-class TestException(msg: String) extends Exception(msg) \ No newline at end of file
+class TestException(msg: String) extends Exception(msg)
+
+/** Custom receiver for testing whether all data received by a receiver gets processed or not */
+class TestReceiver extends NetworkReceiver[Int] {
+ protected lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+ protected def onStart() {
+ blockGenerator.start()
+ logInfo("BlockGenerator started on thread " + receivingThread)
+ try {
+ while(true) {
+ blockGenerator += TestReceiver.counter.getAndIncrement
+ Thread.sleep(0)
+ }
+ } finally {
+ logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
+ }
+ }
+
+ protected def onStop() {
+ blockGenerator.stop()
+ }
+}
+
+object TestReceiver {
+ val counter = new AtomicInteger(1)
+}
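
The new lifecycle tests pin down a small state machine for StreamingContext:
Initialized -> (start) -> Started -> (stop) -> Stopped, where stop() before
start() is a harmless no-op, repeated stop() calls are safe, and start() after
stop() raises SparkException. A hedged sketch of that contract (the demo object
is hypothetical; only the StreamingContext calls come from the patch):

    import org.apache.spark.SparkException
    import org.apache.spark.streaming.{Milliseconds, StreamingContext}

    object LifecycleSketch extends App {
      val ssc = new StreamingContext("local[2]", "lifecycle-sketch", Milliseconds(500))
      // (the tests register an input stream here before calling start())
      ssc.stop()      // stop() before start(): must not throw
      ssc.start()     // Initialized -> Started
      // A graceful stop waits for already-received records to be processed,
      // which is why the "stop gracefully" test tolerates the receiver's
      // counter being at most one or two records ahead of the processed count.
      ssc.stop(stopSparkContext = true, stopGracefully = true)
      ssc.stop()      // second stop(): must not throw
      try ssc.start() // Stopped is terminal: restarting is rejected
      catch { case e: SparkException => println("restart rejected: " + e.getMessage) }
    }
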
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 201630672a..aa2d5c2fc2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -277,7 +277,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
- Thread.sleep(500) // Give some time for old RDDs to be forgotten
+ Thread.sleep(100) // Give some time for old RDDs to be forgotten
} catch {
case e: Exception => {e.printStackTrace(); throw e}
} finally {
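
The post-check sleep in TestSuiteBase drops from 500 ms to 100 ms. Fixed sleeps
trade test latency against flakiness; a common alternative (a sketch, not what
this patch does) is to poll the condition with ScalaTest's Eventually until a
timeout, reusing the suite's output-size assertion for illustration:

    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    // Retry the block every 50 ms, failing only if it still does not hold
    // after 10 seconds; `output` and `numExpectedOutput` are the values from
    // the surrounding TestSuiteBase code.
    eventually(timeout(10.seconds), interval(50.milliseconds)) {
      assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
    }
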