aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2012-10-24 14:44:20 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2012-10-24 14:44:20 -0700
commit1ef6ea25135fd33a7913944628b67f24c87db1f5 (patch)
tree1a8dcd3994fa3cd317b901f2882ce9da273af073 /streaming
parent020d6434844b22c2fe611303b338eaf53397c9db (diff)
downloadspark-1ef6ea25135fd33a7913944628b67f24c87db1f5.tar.gz
spark-1ef6ea25135fd33a7913944628b67f24c87db1f5.tar.bz2
spark-1ef6ea25135fd33a7913944628b67f24c87db1f5.zip
Added tests for testing network input stream.
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/spark/streaming/DStream.scala5
-rw-r--r--streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala9
-rw-r--r--streaming/src/main/scala/spark/streaming/SocketInputDStream.scala17
-rw-r--r--streaming/src/main/scala/spark/streaming/StreamingContext.scala14
-rw-r--r--streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala112
5 files changed, 134 insertions, 23 deletions
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 38bb7c8b94..4bc063719c 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -93,8 +93,9 @@ extends Serializable with Logging {
* its parent DStreams.
*/
protected[streaming] def initialize(time: Time) {
- if (zeroTime != null) {
- throw new Exception("ZeroTime is already initialized, cannot initialize it again")
+ if (zeroTime != null && zeroTime != time) {
+ throw new Exception("ZeroTime is already initialized to " + zeroTime
+ + ", cannot initialize it again to " + time)
}
zeroTime = time
dependencies.foreach(_.initialize(zeroTime))
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
index 9b1b8813de..07ef79415d 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -108,10 +108,11 @@ class NetworkInputTracker(
}
def stopReceivers() {
- implicit val ec = env.actorSystem.dispatcher
- val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList
- val futureOfList = Future.sequence(listOfFutures)
- Await.result(futureOfList, timeout)
+ //implicit val ec = env.actorSystem.dispatcher
+ receiverInfo.values.foreach(_ ! StopReceiver)
+ //val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList
+ //val futureOfList = Future.sequence(listOfFutures)
+ //Await.result(futureOfList, timeout)
}
}
}
diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
index 4dbf421687..8ff7865ca4 100644
--- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
@@ -18,12 +18,12 @@ class SocketInputDStream[T: ClassManifest](
) extends NetworkInputDStream[T](ssc_) {
def createReceiver(): NetworkReceiver[T] = {
- new ObjectInputReceiver(id, host, port, bytesToObjects, storageLevel)
+ new SocketReceiver(id, host, port, bytesToObjects, storageLevel)
}
}
-class ObjectInputReceiver[T: ClassManifest](
+class SocketReceiver[T: ClassManifest](
streamId: Int,
host: String,
port: Int,
@@ -120,7 +120,12 @@ class ObjectInputReceiver[T: ClassManifest](
}
-object ObjectInputReceiver {
+object SocketReceiver {
+
+ /**
+   * This method translates the data from an input stream (say, from a socket)
+ * to '\n' delimited strings and returns an iterator to access the strings.
+ */
def bytesToLines(inputStream: InputStream): Iterator[String] = {
val bufferedInputStream = new BufferedInputStream(inputStream)
val dataInputStream = new DataInputStream(bufferedInputStream)
@@ -133,7 +138,11 @@ object ObjectInputReceiver {
private def getNext() {
try {
nextValue = dataInputStream.readLine()
- println("[" + nextValue + "]")
+ if (nextValue != null) {
+ println("[" + nextValue + "]")
+ } else {
+ gotNext = false
+ }
} catch {
case eof: EOFException =>
finished = true
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 228f1a3616..e124b8cfa0 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -91,7 +91,7 @@ class StreamingContext (
port: Int,
storageLevel: StorageLevel = StorageLevel.DISK_AND_MEMORY_2
): DStream[String] = {
- networkStream[String](hostname, port, ObjectInputReceiver.bytesToLines, storageLevel)
+ networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel)
}
def networkStream[T: ClassManifest](
@@ -115,18 +115,6 @@ class StreamingContext (
inputStream
}
- /*
- def createHttpTextStream(url: String): DStream[String] = {
- createHttpStream(url, ObjectInputReceiver.bytesToLines)
- }
-
- def createHttpStream[T: ClassManifest](
- url: String,
- converter: (InputStream) => Iterator[T]
- ): DStream[T] = {
- }
- */
-
/**
* This function creates a input stream that monitors a Hadoop-compatible
* for new files and executes the necessary processing on them.
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
new file mode 100644
index 0000000000..dd872059ea
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -0,0 +1,112 @@
+package spark.streaming
+
+import java.net.{SocketException, Socket, ServerSocket}
+import java.io.{BufferedWriter, OutputStreamWriter}
+import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import util.ManualClock
+import spark.storage.StorageLevel
+
+
+class InputStreamsSuite extends TestSuiteBase {
+
+ test("network input stream") {
+ val serverPort = 9999
+ val server = new TestServer(9999)
+ server.start()
+ val ssc = new StreamingContext(master, framework)
+ ssc.setBatchDuration(batchDuration)
+
+ val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.DISK_AND_MEMORY)
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]]
+ val outputStream = new TestOutputStream(networkStream, outputBuffer)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3)
+ val expectedOutput = input.map(_.toString)
+ for (i <- 0 until input.size) {
+ server.send(input(i).toString + "\n")
+ Thread.sleep(1000)
+ clock.addToTime(1000)
+ }
+ val startTime = System.currentTimeMillis()
+ while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size)
+ Thread.sleep(100)
+ }
+ Thread.sleep(5000)
+ val timeTaken = System.currentTimeMillis() - startTime
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+
+ ssc.stop()
+ server.stop()
+
+ assert(outputBuffer.size === expectedOutput.size)
+ for (i <- 0 until outputBuffer.size) {
+ assert(outputBuffer(i).size === 1)
+ assert(outputBuffer(i).head === expectedOutput(i))
+ }
+ }
+}
+
+
+class TestServer(port: Int) {
+
+ val queue = new ArrayBlockingQueue[String](100)
+
+ val serverSocket = new ServerSocket(port)
+
+ val servingThread = new Thread() {
+ override def run() {
+ try {
+ while(true) {
+ println("Accepting connections on port " + port)
+ val clientSocket = serverSocket.accept()
+ println("New connection")
+ try {
+ clientSocket.setTcpNoDelay(true)
+ val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream))
+
+ while(clientSocket.isConnected) {
+ val msg = queue.poll(100, TimeUnit.MILLISECONDS)
+ if (msg != null) {
+ outputStream.write(msg)
+ outputStream.flush()
+ println("Message '" + msg + "' sent")
+ }
+ }
+ } catch {
+ case e: SocketException => println(e)
+ } finally {
+ println("Connection closed")
+ if (!clientSocket.isClosed) clientSocket.close()
+ }
+ }
+ } catch {
+ case ie: InterruptedException =>
+
+ } finally {
+ serverSocket.close()
+ }
+ }
+ }
+
+ def start() { servingThread.start() }
+
+ def send(msg: String) { queue.add(msg) }
+
+ def stop() { servingThread.interrupt() }
+}
+
+object TestServer {
+ def main(args: Array[String]) {
+ val s = new TestServer(9999)
+ s.start()
+ while(true) {
+ Thread.sleep(1000)
+ s.send("hello")
+ }
+ }
+}