aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2013-11-12 00:10:45 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2013-11-12 00:10:45 -0800
commit7ccbbdacb9406d67b5acf2a489d6551900babdc9 (patch)
tree8592be6ee2e99905ff9d444e81999c8c654507e3 /streaming
parent23b53efccdc6068d6781ae3ed7b656ae5008c8fb (diff)
downloadspark-7ccbbdacb9406d67b5acf2a489d6551900babdc9.tar.gz
spark-7ccbbdacb9406d67b5acf2a489d6551900babdc9.tar.bz2
spark-7ccbbdacb9406d67b5acf2a489d6551900babdc9.zip
Made block generator thread safe to fix Kafka bug.
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala4
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala83
2 files changed, 80 insertions, 7 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 8d3ac0fc65..a82862c802 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -232,11 +232,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
logInfo("Data handler stopped")
}
- def += (obj: T) {
+ def += (obj: T): Unit = synchronized {
currentBuffer += obj
}
- private def updateCurrentBuffer(time: Long) {
+ private def updateCurrentBuffer(time: Long): Unit = synchronized {
try {
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index c29b75ece6..a559db468a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -23,15 +23,15 @@ import akka.actor.IOManager
import akka.actor.Props
import akka.util.ByteString
-import dstream.SparkFlumeEvent
+import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent}
import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter}
-import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import util.ManualClock
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receivers.Receiver
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
import scala.util.Random
import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter
@@ -44,6 +44,7 @@ import java.nio.ByteBuffer
import collection.JavaConversions._
import java.nio.charset.Charset
import com.google.common.io.Files
+import java.util.concurrent.atomic.AtomicInteger
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
@@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
System.clearProperty("spark.hostPort")
}
-
test("socket input stream") {
// Start the server
val testServer = new TestServer()
@@ -275,10 +275,49 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
kafka.serializer.StringDecoder,
kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
}
+
+ test("multi-thread receiver") {
+ // set up the test receiver
+ val numThreads = 10
+ val numRecordsPerThread = 1000
+ val numTotalRecords = numThreads * numRecordsPerThread
+ val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread)
+ MultiThreadTestReceiver.haveAllThreadsFinished = false
+
+ // set up the network stream using the test receiver
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val networkStream = ssc.networkStream[Int](testReceiver)
+ val countStream = networkStream.count
+ val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
+ val outputStream = new TestOutputStream(countStream, outputBuffer)
+ def output = outputBuffer.flatMap(x => x)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Let the data from the receiver be received
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val startTime = System.currentTimeMillis()
+ while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) &&
+ System.currentTimeMillis() - startTime < 5000) {
+ Thread.sleep(100)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(1000)
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+ assert(output.sum === numTotalRecords)
+ }
}
-/** This is server to test the network input stream */
+/** This is a server to test the network input stream */
class TestServer() extends Logging {
val queue = new ArrayBlockingQueue[String](100)
@@ -340,6 +379,7 @@ object TestServer {
}
}
+/** This is an actor for testing actor input stream */
class TestActor(port: Int) extends Actor with Receiver {
def bytesToString(byteString: ByteString) = byteString.utf8String
@@ -351,3 +391,36 @@ class TestActor(port: Int) extends Actor with Receiver {
pushBlock(bytesToString(bytes))
}
}
+
+/** This is a receiver to test multiple threads inserting data using block generator */
+class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
+ extends NetworkReceiver[Int] {
+ lazy val executorPool = Executors.newFixedThreadPool(numThreads)
+ lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+ lazy val finishCount = new AtomicInteger(0)
+
+ protected def onStart() {
+ blockGenerator.start()
+ (1 to numThreads).map(threadId => {
+ val runnable = new Runnable {
+ def run() {
+ (1 to numRecordsPerThread).foreach(i =>
+ blockGenerator += (threadId * numRecordsPerThread + i) )
+ if (finishCount.incrementAndGet == numThreads) {
+ MultiThreadTestReceiver.haveAllThreadsFinished = true
+ }
+ logInfo("Finished thread " + threadId)
+ }
+ }
+ executorPool.submit(runnable)
+ })
+ }
+
+ protected def onStop() {
+ executorPool.shutdown()
+ }
+}
+
+object MultiThreadTestReceiver {
+ var haveAllThreadsFinished = false
+}