aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzsxwing <zsxwing@gmail.com>2015-05-05 02:15:39 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2015-05-05 02:15:39 -0700
commit4d29867ede9a87b160c3d715c1fb02067feef449 (patch)
tree9132f0dfe71737c4f964798e4d7e5f865b7545f2
parent8436f7e98e674020007a9175973c6a1095b6774f (diff)
downloadspark-4d29867ede9a87b160c3d715c1fb02067feef449.tar.gz
spark-4d29867ede9a87b160c3d715c1fb02067feef449.tar.bz2
spark-4d29867ede9a87b160c3d715c1fb02067feef449.zip
[SPARK-7341] [STREAMING] [TESTS] Fix the flaky test: org.apache.spark.stre...
...aming.InputStreamsSuite.socket input stream Remove non-deterministic "Thread.sleep" and use deterministic strategies to fix the flaky failure: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-Maven-pre-YARN/hadoop.version=1.0.4,label=centos/2127/testReport/junit/org.apache.spark.streaming/InputStreamsSuite/socket_input_stream/ Author: zsxwing <zsxwing@gmail.com> Closes #5891 from zsxwing/SPARK-7341 and squashes the following commits: 611157a [zsxwing] Add wait methods to BatchCounter and use BatchCounter in InputStreamsSuite 014b58f [zsxwing] Use withXXX to clean up the resources c9bf746 [zsxwing] Move 'waitForStart' into the 'start' method and fix the code style 9d0de6d [zsxwing] [SPARK-7341][Streaming][Tests] Fix the flaky test: org.apache.spark.streaming.InputStreamsSuite.socket input stream
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala169
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala34
2 files changed, 140 insertions, 63 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index eb13675824..6074532502 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -18,9 +18,9 @@
package org.apache.spark.streaming
import java.io.{File, BufferedWriter, OutputStreamWriter}
-import java.net.{SocketException, ServerSocket}
+import java.net.{Socket, SocketException, ServerSocket}
import java.nio.charset.Charset
-import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue}
@@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.receiver.Receiver
@@ -43,51 +44,57 @@ import org.apache.spark.streaming.receiver.Receiver
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
test("socket input stream") {
- // Start the server
- val testServer = new TestServer()
- testServer.start()
+ withTestServer(new TestServer()) { testServer =>
+ // Start the server
+ testServer.start()
- // Set up the streaming context and input streams
- val ssc = new StreamingContext(conf, batchDuration)
- val networkStream = ssc.socketTextStream(
- "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
- val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
- val outputStream = new TestOutputStream(networkStream, outputBuffer)
- def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
- outputStream.register()
- ssc.start()
-
- // Feed data to the server to send to the network receiver
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- val input = Seq(1, 2, 3, 4, 5)
- val expectedOutput = input.map(_.toString)
- Thread.sleep(1000)
- for (i <- 0 until input.size) {
- testServer.send(input(i).toString + "\n")
- Thread.sleep(500)
- clock.advance(batchDuration.milliseconds)
- }
- Thread.sleep(1000)
- logInfo("Stopping server")
- testServer.stop()
- logInfo("Stopping context")
- ssc.stop()
-
- // Verify whether data received was as expected
- logInfo("--------------------------------")
- logInfo("output.size = " + outputBuffer.size)
- logInfo("output")
- outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
- logInfo("expected output.size = " + expectedOutput.size)
- logInfo("expected output")
- expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
- logInfo("--------------------------------")
+ // Set up the streaming context and input streams
+ withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+ val input = Seq(1, 2, 3, 4, 5)
+ // Use "batchCount" to make sure we check the result after all batches finish
+ val batchCounter = new BatchCounter(ssc)
+ val networkStream = ssc.socketTextStream(
+ "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ val outputStream = new TestOutputStream(networkStream, outputBuffer)
+ outputStream.register()
+ ssc.start()
- // Verify whether all the elements received are as expected
- // (whether the elements were received one in each interval is not verified)
- assert(output.size === expectedOutput.size)
- for (i <- 0 until output.size) {
- assert(output(i) === expectedOutput(i))
+ // Feed data to the server to send to the network receiver
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val expectedOutput = input.map(_.toString)
+ for (i <- 0 until input.size) {
+ testServer.send(input(i).toString + "\n")
+ Thread.sleep(500)
+ clock.advance(batchDuration.milliseconds)
+ }
+ // Make sure we finish all batches before "stop"
+ if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) {
+ fail("Timeout: cannot finish all batches in 30 seconds")
+ }
+ logInfo("Stopping server")
+ testServer.stop()
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Verify whether all the elements received are as expected
+ // (whether the elements were received one in each interval is not verified)
+ val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
+ assert(output.size === expectedOutput.size)
+ for (i <- 0 until output.size) {
+ assert(output(i) === expectedOutput(i))
+ }
+ }
}
}
@@ -368,31 +375,45 @@ class TestServer(portToBind: Int = 0) extends Logging {
val serverSocket = new ServerSocket(portToBind)
+ private val startLatch = new CountDownLatch(1)
+
val servingThread = new Thread() {
override def run() {
try {
while(true) {
logInfo("Accepting connections on port " + port)
val clientSocket = serverSocket.accept()
- logInfo("New connection")
- try {
- clientSocket.setTcpNoDelay(true)
- val outputStream = new BufferedWriter(
- new OutputStreamWriter(clientSocket.getOutputStream))
-
- while(clientSocket.isConnected) {
- val msg = queue.poll(100, TimeUnit.MILLISECONDS)
- if (msg != null) {
- outputStream.write(msg)
- outputStream.flush()
- logInfo("Message '" + msg + "' sent")
+ if (startLatch.getCount == 1) {
+ // The first connection is a test connection to implement "waitForStart", so skip it
+ // and send a signal
+ if (!clientSocket.isClosed) {
+ clientSocket.close()
+ }
+ startLatch.countDown()
+ } else {
+ // Real connections
+ logInfo("New connection")
+ try {
+ clientSocket.setTcpNoDelay(true)
+ val outputStream = new BufferedWriter(
+ new OutputStreamWriter(clientSocket.getOutputStream))
+
+ while (clientSocket.isConnected) {
+ val msg = queue.poll(100, TimeUnit.MILLISECONDS)
+ if (msg != null) {
+ outputStream.write(msg)
+ outputStream.flush()
+ logInfo("Message '" + msg + "' sent")
+ }
+ }
+ } catch {
+ case e: SocketException => logError("TestServer error", e)
+ } finally {
+ logInfo("Connection closed")
+ if (!clientSocket.isClosed) {
+ clientSocket.close()
}
}
- } catch {
- case e: SocketException => logError("TestServer error", e)
- } finally {
- logInfo("Connection closed")
- if (!clientSocket.isClosed) clientSocket.close()
}
}
} catch {
@@ -404,7 +425,29 @@ class TestServer(portToBind: Int = 0) extends Logging {
}
}
- def start() { servingThread.start() }
+ def start(): Unit = {
+ servingThread.start()
+ if (!waitForStart(10000)) {
+ stop()
+ throw new AssertionError("Timeout: TestServer cannot start in 10 seconds")
+ }
+ }
+
+ /**
+ * Wait until the server starts. Return true if the server starts in "millis" milliseconds.
+ * Otherwise, return false to indicate it's timeout.
+ */
+ private def waitForStart(millis: Long): Boolean = {
+ // We will create a test connection to the server so that we can make sure it has started.
+ val socket = new Socket("localhost", port)
+ try {
+ startLatch.await(millis, TimeUnit.MILLISECONDS)
+ } finally {
+ if (!socket.isClosed) {
+ socket.close()
+ }
+ }
+ }
def send(msg: String) { queue.put(msg) }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 2ba86aeaf9..4d0cd7516f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -146,6 +146,40 @@ class BatchCounter(ssc: StreamingContext) {
def getNumStartedBatches: Int = this.synchronized {
numStartedBatches
}
+
+ /**
+ * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
+ * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
+ * timeout.
+ *
+ * @param expectedNumCompletedBatches the `expectedNumCompletedBatches` batches to wait
+ * @param timeout the maximum time to wait in milliseconds.
+ */
+ def waitUntilBatchesCompleted(expectedNumCompletedBatches: Int, timeout: Long): Boolean =
+ waitUntilConditionBecomeTrue(numCompletedBatches >= expectedNumCompletedBatches, timeout)
+
+ /**
+ * Wait until `expectedNumStartedBatches` batches are completed, or timeout. Return true if
+ * `expectedNumStartedBatches` batches are completed. Otherwise, return false to indicate it's
+ * timeout.
+ *
+ * @param expectedNumStartedBatches the `expectedNumStartedBatches` batches to wait
+ * @param timeout the maximum time to wait in milliseconds.
+ */
+ def waitUntilBatchesStarted(expectedNumStartedBatches: Int, timeout: Long): Boolean =
+ waitUntilConditionBecomeTrue(numStartedBatches >= expectedNumStartedBatches, timeout)
+
+ private def waitUntilConditionBecomeTrue(condition: => Boolean, timeout: Long): Boolean = {
+ synchronized {
+ var now = System.currentTimeMillis()
+ val timeoutTick = now + timeout
+ while (!condition && timeoutTick > now) {
+ wait(timeoutTick - now)
+ now = System.currentTimeMillis()
+ }
+ condition
+ }
+ }
}
/**