From 7ccbbdacb9406d67b5acf2a489d6551900babdc9 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Tue, 12 Nov 2013 00:10:45 -0800
Subject: Made block generator thread safe to fix Kafka bug.

---
 .../streaming/dstream/NetworkInputDStream.scala    |  4 +-
 .../apache/spark/streaming/InputStreamsSuite.scala | 83 ++++++++++++++++++++--
 2 files changed, 80 insertions(+), 7 deletions(-)

(limited to 'streaming')

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 8d3ac0fc65..a82862c802 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -232,11 +232,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
     logInfo("Data handler stopped")
   }
 
-  def += (obj: T) {
+  def += (obj: T): Unit = synchronized {
     currentBuffer += obj
   }
 
-  private def updateCurrentBuffer(time: Long) {
+  private def updateCurrentBuffer(time: Long): Unit = synchronized {
     try {
       val newBlockBuffer = currentBuffer
       currentBuffer = new ArrayBuffer[T]
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index c29b75ece6..a559db468a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -23,15 +23,15 @@ import akka.actor.IOManager
 import akka.actor.Props
 import akka.util.ByteString
 
-import dstream.SparkFlumeEvent
+import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent}
 import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
 import java.io.{File, BufferedWriter, OutputStreamWriter}
-import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
 import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
 import util.ManualClock
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.receivers.Receiver
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
 import scala.util.Random
 import org.apache.commons.io.FileUtils
 import org.scalatest.BeforeAndAfter
@@ -44,6 +44,7 @@ import java.nio.ByteBuffer
 import collection.JavaConversions._
 import java.nio.charset.Charset
 import com.google.common.io.Files
+import java.util.concurrent.atomic.AtomicInteger
 
 class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
@@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     System.clearProperty("spark.hostPort")
   }
 
-
   test("socket input stream") {
     // Start the server
     val testServer = new TestServer()
@@ -275,10 +275,49 @@
       kafka.serializer.StringDecoder,
       kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
   }
+
+  test("multi-thread receiver") {
+    // set up the test receiver
+    val numThreads = 10
+    val numRecordsPerThread = 1000
+    val numTotalRecords = numThreads * numRecordsPerThread
+    val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread)
+    MultiThreadTestReceiver.haveAllThreadsFinished = false
+
+    // set up the network stream using the test receiver
+    val ssc = new StreamingContext(master, framework, batchDuration)
+    val networkStream = ssc.networkStream[Int](testReceiver)
+    val countStream = networkStream.count
+    val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
+    val outputStream = new TestOutputStream(countStream, outputBuffer)
+    def output = outputBuffer.flatMap(x => x)
+    ssc.registerOutputStream(outputStream)
+    ssc.start()
+
+    // Let the data from the receiver be received
+    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+    val startTime = System.currentTimeMillis()
+    while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) &&
+      System.currentTimeMillis() - startTime < 5000) {
+      Thread.sleep(100)
+      clock.addToTime(batchDuration.milliseconds)
+    }
+    Thread.sleep(1000)
+    logInfo("Stopping context")
+    ssc.stop()
+
+    // Verify whether data received was as expected
+    logInfo("--------------------------------")
+    logInfo("output.size = " + outputBuffer.size)
+    logInfo("output")
+    outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+    logInfo("--------------------------------")
+    assert(output.sum === numTotalRecords)
+  }
 }
 
 
-/** This is server to test the network input stream */
+/** This is a server to test the network input stream */
 class TestServer() extends Logging {
 
   val queue = new ArrayBlockingQueue[String](100)
@@ -340,6 +379,7 @@ object TestServer {
   }
 }
 
+/** This is an actor for testing actor input stream */
 class TestActor(port: Int) extends Actor with Receiver {
 
   def bytesToString(byteString: ByteString) = byteString.utf8String
@@ -351,3 +391,36 @@ class TestActor(port: Int) extends Actor with Receiver {
     pushBlock(bytesToString(bytes))
   }
 }
+
+/** This is a receiver to test multiple threads inserting data using block generator */
+class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
+  extends NetworkReceiver[Int] {
+  lazy val executorPool = Executors.newFixedThreadPool(numThreads)
+  lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+  lazy val finishCount = new AtomicInteger(0)
+
+  protected def onStart() {
+    blockGenerator.start()
+    (1 to numThreads).map(threadId => {
+      val runnable = new Runnable {
+        def run() {
+          (1 to numRecordsPerThread).foreach(i =>
+            blockGenerator += (threadId * numRecordsPerThread + i) )
+          if (finishCount.incrementAndGet == numThreads) {
+            MultiThreadTestReceiver.haveAllThreadsFinished = true
+          }
+          logInfo("Finished thread " + threadId)
+        }
+      }
+      executorPool.submit(runnable)
+    })
+  }
+
+  protected def onStop() {
+    executorPool.shutdown()
+  }
+}
+
+object MultiThreadTestReceiver {
+  var haveAllThreadsFinished = false
+}
--
cgit v1.2.3
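
For context, a minimal, self-contained sketch (not the Spark code itself) of the race this patch closes: receiver threads append records to a shared buffer while a timer thread periodically swaps that buffer out as a block. If the append and the swap do not synchronize on the same lock, an append can interleave with the swap and records are silently dropped, which is the failure mode seen when several Kafka consumer threads feed one receiver. The names SimpleBlockGenerator and drainBlock below are illustrative only and are not part of Spark's API.

    import scala.collection.mutable.ArrayBuffer

    // Illustrative sketch of the synchronized buffer-swap pattern; class and
    // method names are hypothetical, not Spark's BlockGenerator API.
    class SimpleBlockGenerator[T] {
      private var currentBuffer = new ArrayBuffer[T]

      // Called concurrently by any number of producer threads.
      def +=(obj: T): Unit = synchronized {
        currentBuffer += obj
      }

      // Called periodically by a timer thread; swaps in a fresh buffer and
      // returns the completed block. Locking the same monitor as += ensures
      // every record lands in exactly one block.
      def drainBlock(): Seq[T] = synchronized {
        val block = currentBuffer
        currentBuffer = new ArrayBuffer[T]
        block
      }
    }

Because both methods lock the generator instance, an append observes either the old buffer or the freshly swapped-in one, never a buffer mid-replacement.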