diff options
author | zsxwing <zsxwing@gmail.com> | 2015-05-05 02:15:39 -0700 |
---|---|---|
committer | Tathagata Das <tathagata.das1565@gmail.com> | 2015-05-05 02:15:48 -0700 |
commit | 06345106861568a8a7dcb79372dd95494cc33e27 (patch) | |
tree | 7d9228e96b9bfbf6800205c5673e28573bd52b35 | |
parent | becdb811df35c2d6e3fcd3c70bee00d4f40398c8 (diff) | |
download | spark-06345106861568a8a7dcb79372dd95494cc33e27.tar.gz spark-06345106861568a8a7dcb79372dd95494cc33e27.tar.bz2 spark-06345106861568a8a7dcb79372dd95494cc33e27.zip |
[SPARK-7341] [STREAMING] [TESTS] Fix the flaky test: org.apache.spark.stre...
...aming.InputStreamsSuite.socket input stream
Remove non-deterministic "Thread.sleep" and use deterministic strategies to fix the flaky failure: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-Maven-pre-YARN/hadoop.version=1.0.4,label=centos/2127/testReport/junit/org.apache.spark.streaming/InputStreamsSuite/socket_input_stream/
Author: zsxwing <zsxwing@gmail.com>
Closes #5891 from zsxwing/SPARK-7341 and squashes the following commits:
611157a [zsxwing] Add wait methods to BatchCounter and use BatchCounter in InputStreamsSuite
014b58f [zsxwing] Use withXXX to clean up the resources
c9bf746 [zsxwing] Move 'waitForStart' into the 'start' method and fix the code style
9d0de6d [zsxwing] [SPARK-7341][Streaming][Tests] Fix the flaky test: org.apache.spark.streaming.InputStreamsSuite.socket input stream
(cherry picked from commit 4d29867ede9a87b160c3d715c1fb02067feef449)
Signed-off-by: Tathagata Das <tathagata.das1565@gmail.com>
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala | 169 | ||||
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 34 |
2 files changed, 140 insertions, 63 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index eb13675824..6074532502 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.streaming import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.net.{SocketException, ServerSocket} +import java.net.{Socket, SocketException, ServerSocket} import java.nio.charset.Charset -import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} @@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.receiver.Receiver @@ -43,51 +44,57 @@ import org.apache.spark.streaming.receiver.Receiver class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("socket input stream") { - // Start the server - val testServer = new TestServer() - testServer.start() + withTestServer(new TestServer()) { testServer => + // Start the server + testServer.start() - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.socketTextStream( - "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - val expectedOutput = input.map(_.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString + "\n") - Thread.sleep(500) - clock.advance(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val input = Seq(1, 2, 3, 4, 5) + // Use "batchCount" to make sure we check the result after all batches finish + val batchCounter = new BatchCounter(ssc) + val networkStream = ssc.socketTextStream( + "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + outputStream.register() + ssc.start() - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val expectedOutput = input.map(_.toString) + for (i <- 0 until input.size) { + testServer.send(input(i).toString + "\n") + Thread.sleep(500) + clock.advance(batchDuration.milliseconds) + } + // Make sure we finish all batches before "stop" + if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) { + fail("Timeout: cannot finish all batches in 30 seconds") + } + logInfo("Stopping server") + testServer.stop() + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } } } @@ -368,31 +375,45 @@ class TestServer(portToBind: Int = 0) extends Logging { val serverSocket = new ServerSocket(portToBind) + private val startLatch = new CountDownLatch(1) + val servingThread = new Thread() { override def run() { try { while(true) { logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() - logInfo("New connection") - try { - clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter( - new OutputStreamWriter(clientSocket.getOutputStream)) - - while(clientSocket.isConnected) { - val msg = queue.poll(100, TimeUnit.MILLISECONDS) - if (msg != null) { - outputStream.write(msg) - outputStream.flush() - logInfo("Message '" + msg + "' sent") + if (startLatch.getCount == 1) { + // The first connection is a test connection to implement "waitForStart", so skip it + // and send a signal + if (!clientSocket.isClosed) { + clientSocket.close() + } + startLatch.countDown() + } else { + // Real connections + logInfo("New connection") + try { + clientSocket.setTcpNoDelay(true) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) + + while (clientSocket.isConnected) { + val msg = queue.poll(100, TimeUnit.MILLISECONDS) + if (msg != null) { + outputStream.write(msg) + outputStream.flush() + logInfo("Message '" + msg + "' sent") + } + } + } catch { + case e: SocketException => logError("TestServer error", e) + } finally { + logInfo("Connection closed") + if (!clientSocket.isClosed) { + clientSocket.close() } } - } catch { - case e: SocketException => logError("TestServer error", e) - } finally { - logInfo("Connection closed") - if (!clientSocket.isClosed) clientSocket.close() } } } catch { @@ -404,7 +425,29 @@ class TestServer(portToBind: Int = 0) extends Logging { } } - def start() { servingThread.start() } + def start(): Unit = { + servingThread.start() + if (!waitForStart(10000)) { + stop() + throw new AssertionError("Timeout: TestServer cannot start in 10 seconds") + } + } + + /** + * Wait until the server starts. Return true if the server starts in "millis" milliseconds. + * Otherwise, return false to indicate it's timeout. + */ + private def waitForStart(millis: Long): Boolean = { + // We will create a test connection to the server so that we can make sure it has started. + val socket = new Socket("localhost", port) + try { + startLatch.await(millis, TimeUnit.MILLISECONDS) + } finally { + if (!socket.isClosed) { + socket.close() + } + } + } def send(msg: String) { queue.put(msg) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 2ba86aeaf9..4d0cd7516f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -146,6 +146,40 @@ class BatchCounter(ssc: StreamingContext) { def getNumStartedBatches: Int = this.synchronized { numStartedBatches } + + /** + * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if + * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's + * timeout. + * + * @param expectedNumCompletedBatches the `expectedNumCompletedBatches` batches to wait + * @param timeout the maximum time to wait in milliseconds. + */ + def waitUntilBatchesCompleted(expectedNumCompletedBatches: Int, timeout: Long): Boolean = + waitUntilConditionBecomeTrue(numCompletedBatches >= expectedNumCompletedBatches, timeout) + + /** + * Wait until `expectedNumStartedBatches` batches are completed, or timeout. Return true if + * `expectedNumStartedBatches` batches are completed. Otherwise, return false to indicate it's + * timeout. + * + * @param expectedNumStartedBatches the `expectedNumStartedBatches` batches to wait + * @param timeout the maximum time to wait in milliseconds. + */ + def waitUntilBatchesStarted(expectedNumStartedBatches: Int, timeout: Long): Boolean = + waitUntilConditionBecomeTrue(numStartedBatches >= expectedNumStartedBatches, timeout) + + private def waitUntilConditionBecomeTrue(condition: => Boolean, timeout: Long): Boolean = { + synchronized { + var now = System.currentTimeMillis() + val timeoutTick = now + timeout + while (!condition && timeoutTick > now) { + wait(timeoutTick - now) + now = System.currentTimeMillis() + } + condition + } + } } /** |