From a6de5758f1a48e6c25b441440d8cd84546857326 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Tue, 23 Oct 2012 01:41:13 -0700
Subject: Modified API of NetworkInputDStreams and got ObjectInputDStream and
 RawInputDStream working.

---
 .../scala/spark/streaming/FileInputDStream.scala   |  18 --
 .../spark/streaming/NetworkInputDStream.scala      | 139 +++++++++++-
 .../streaming/NetworkInputReceiverMessage.scala    |   7 -
 .../spark/streaming/NetworkInputTracker.scala      |  84 +++----
 .../scala/spark/streaming/ObjectInputDStream.scala | 169 +++++++++++++-
 .../spark/streaming/ObjectInputReceiver.scala      | 244 ---------------------
 .../scala/spark/streaming/RawInputDStream.scala    |  77 ++-----
 .../scala/spark/streaming/StreamingContext.scala   |   4 +-
 8 files changed, 360 insertions(+), 382 deletions(-)
 delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala
 delete mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala

diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
index 29ae89616e..78537b8794 100644
--- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
@@ -19,15 +19,6 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
   @transient private var path_ : Path = null
   @transient private var fs_ : FileSystem = null
 
-  /*
-  @transient @noinline lazy val path = {
-    //if (directory == null) throw new Exception("directory is null")
-    //println(directory)
-    new Path(directory)
-  }
-  @transient lazy val fs = path.getFileSystem(new Configuration())
-  */
-
   var lastModTime: Long = 0
 
   def path(): Path = {
@@ -79,15 +70,6 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
       file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)))
     Some(newRDD)
   }
-  /*
-  @throws(classOf[IOException])
-  private def readObject(ois: ObjectInputStream) {
-    println(this.getClass().getSimpleName + ".readObject used")
-    ois.defaultReadObject()
-    println("HERE HERE" + this.directory)
-  }
-  */
-
 }
 
 object FileInputDStream {
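The surviving FileInputDStream code keeps the @transient-var-plus-accessor pattern (path_/fs_ guarded by path()/fs()) in place of the deleted @transient lazy vals. As a standalone sketch of the same idea, with invented names (LazyTransient, cached_), a non-serializable field is dropped during serialization and rebuilt on first access after deserialization:

class LazyTransient(directory: String) extends Serializable {
  // Not serialized; reset to null when this object is deserialized on a worker
  @transient private var cached_ : StringBuilder = null

  // Rebuilds the field on first access after deserialization
  def cached(): StringBuilder = {
    if (cached_ == null) cached_ = new StringBuilder(directory)
    cached_
  }
}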
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
index bf83f98ec4..6b41e4d2c8 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
@@ -1,13 +1,30 @@
 package spark.streaming
 
-import spark.RDD
-import spark.BlockRDD
+import spark.{Logging, SparkEnv, RDD, BlockRDD}
+import spark.storage.StorageLevel
 
-abstract class NetworkInputDStream[T: ClassManifest](@transient ssc: StreamingContext)
-  extends InputDStream[T](ssc) {
+import java.nio.ByteBuffer
 
-  val id = ssc.getNewNetworkStreamId()
-
+import akka.actor.{Props, Actor}
+import akka.pattern.ask
+import akka.dispatch.Await
+import akka.util.duration._
+
+abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext)
+  extends InputDStream[T](ssc_) {
+
+  // This is a unique identifier that is used to match the network receiver with the
+  // corresponding network input stream.
+  val id = ssc.getNewNetworkStreamId()
+
+  /**
+   * This method creates the receiver object that will be sent to the workers
+   * to receive data. This method needs to be defined by any specific implementation
+   * of a NetworkInputDStream.
+   */
+  def createReceiver(): NetworkReceiver[T]
+
+  // Nothing to start or stop as both are taken care of by the NetworkInputTracker.
   def start() {}
 
   def stop() {}
@@ -16,8 +33,114 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc: StreamingCo
     val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
     Some(new BlockRDD[T](ssc.sc, blockIds))
   }
+}
+
+
+sealed trait NetworkReceiverMessage
+case class StopReceiver(msg: String) extends NetworkReceiverMessage
+case class ReportBlock(blockId: String) extends NetworkReceiverMessage
+case class ReportError(msg: String) extends NetworkReceiverMessage
+
+abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging {
+
+  initLogging()
+
+  lazy protected val env = SparkEnv.get
+
+  lazy protected val actor = env.actorSystem.actorOf(
+    Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId)
+
+  lazy protected val receivingThread = Thread.currentThread()
+
+  /** This method will be called to start receiving data. */
+  protected def onStart()
+
+  /** This method will be called to stop receiving data. */
+  protected def onStop()
+
+  /**
+   * This method starts the receiver. First it accesses all the lazy members to
+   * materialize them. Then it calls the user-defined onStart() method to start
+   * other threads, etc. required to receive the data.
+   */
+  def start() {
+    try {
+      // Access the lazy vals to materialize them
+      env
+      actor
+      receivingThread
+
+      // Call user-defined onStart()
+      onStart()
+    } catch {
+      case ie: InterruptedException =>
+        logWarning("Receiving thread interrupted")
+      case e: Exception =>
+        stopOnError(e)
+    }
+  }
+
+  /**
+   * This method stops the receiver. First it interrupts the main receiving thread,
+   * that is, the thread that called receiver.start(). Then it calls the user-defined
+   * onStop() method to stop other threads and/or do cleanup.
+   */
+  def stop() {
+    receivingThread.interrupt()
+    onStop()
+    //TODO: terminate the actor
+  }
+
+  /**
+   * This method stops the receiver and reports the exception to the tracker.
+   * This should be called whenever an exception has happened on any thread
+   * of the receiver.
+   */
+  protected def stopOnError(e: Exception) {
+    logError("Error receiving data", e)
+    stop()
+    actor ! ReportError(e.toString)
+  }
 
-  /** Called on workers to run a receiver for the stream. */
-  def runReceiver(): Unit
+  /**
+   * This method pushes a block (as an iterator of values) into the block manager.
+   */
+  protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) {
+    env.blockManager.put(blockId, iterator, level)
+    actor ! ReportBlock(blockId)
+  }
+
+  /**
+   * This method pushes a block (as bytes) into the block manager.
+   */
+  protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+    env.blockManager.putBytes(blockId, bytes, level)
+    actor ! ReportBlock(blockId)
+  }
+
+  /** A helper actor that communicates with the NetworkInputTracker */
+  private class NetworkReceiverActor extends Actor {
+    logInfo("Attempting to register with tracker")
+    val ip = System.getProperty("spark.master.host", "localhost")
+    val port = System.getProperty("spark.master.port", "7077").toInt
+    val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
+    val tracker = env.actorSystem.actorFor(url)
+    val timeout = 5.seconds
+
+    override def preStart() {
+      val future = tracker.ask(RegisterReceiver(streamId, self))(timeout)
+      Await.result(future, timeout)
+    }
+
+    override def receive() = {
+      case ReportBlock(blockId) =>
+        tracker ! AddBlocks(streamId, Array(blockId))
+      case ReportError(msg) =>
+        tracker ! DeregisterReceiver(streamId, msg)
+      case StopReceiver(msg) =>
+        stop()
+        tracker ! DeregisterReceiver(streamId, msg)
    }
  }
+}
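With the receiver API introduced above (createReceiver() plus the onStart()/onStop() hooks and pushBlock()), a concrete receiver reduces to a blocking read loop. A minimal sketch against that API follows; QueueReceiver and the in-process BlockingQueue source are invented for illustration, and only the NetworkReceiver/pushBlock/StorageLevel names come from the patch:

import java.util.concurrent.BlockingQueue
import spark.storage.StorageLevel

// Hypothetical receiver: pulls strings off an in-process queue and pushes
// each one as its own block; everything but the NetworkReceiver API is made up.
class QueueReceiver(streamId: Int, queue: BlockingQueue[String])
  extends NetworkReceiver[String](streamId) {

  protected def onStart() {
    var nextBlockNumber = 0
    while (true) {
      val item = queue.take()  // blocks on the receiving thread
      val blockId = "input-" + streamId + "-" + nextBlockNumber
      nextBlockNumber += 1
      // Store the block and report its ID to the tracker in one call
      pushBlock(blockId, Iterator(item), StorageLevel.DISK_AND_MEMORY)
    }
  }

  // stop() already interrupts the receiving thread, which unblocks take();
  // nothing else to clean up in this sketch.
  protected def onStop() {}
}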
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala
deleted file mode 100644
index deaffe98c8..0000000000
--- a/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala
+++ /dev/null
@@ -1,7 +0,0 @@
-package spark.streaming
-
-sealed trait NetworkInputReceiverMessage
-
-case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage
-case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage
-case object StopReceiver extends NetworkInputReceiverMessage
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
index 9f9001e4d5..9b1b8813de 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -13,13 +13,44 @@ import akka.dispatch._
 
 trait NetworkInputTrackerMessage
 case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
+case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage
+case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
+
 
 class NetworkInputTracker(
   @transient ssc: StreamingContext,
-  @transient networkInputStreams: Array[NetworkInputDStream[_]])
-extends Logging {
+  @transient networkInputStreams: Array[NetworkInputDStream[_]])
+  extends Logging {
+
+  val networkInputStreamIds = networkInputStreams.map(_.id).toArray
+  val receiverExecutor = new ReceiverExecutor()
+  val receiverInfo = new HashMap[Int, ActorRef]
+  val receivedBlockIds = new HashMap[Int, Queue[String]]
+  val timeout = 5000.milliseconds
 
-  class TrackerActor extends Actor {
+  var currentTime: Time = null
+
+  def start() {
+    ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
+    receiverExecutor.start()
+  }
+
+  def stop() {
+    receiverExecutor.interrupt()
+    receiverExecutor.stopReceivers()
+  }
+
+  def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized {
+    val queue = receivedBlockIds.synchronized {
+      receivedBlockIds.getOrElse(receiverId, new Queue[String]())
+    }
+    val result = queue.synchronized {
+      queue.dequeueAll(x => true)
+    }
+    result.toArray
+  }
+
+  private class NetworkInputTrackerActor extends Actor {
     def receive = {
       case RegisterReceiver(streamId, receiverActor) => {
         if (!networkInputStreamIds.contains(streamId)) {
@@ -29,7 +60,7 @@ extends Logging {
         logInfo("Registered receiver for network stream " + streamId)
         sender ! true
       }
-      case GotBlockIds(streamId, blockIds) => {
+      case AddBlocks(streamId, blockIds) => {
         val tmp = receivedBlockIds.synchronized {
           if (!receivedBlockIds.contains(streamId)) {
             receivedBlockIds += ((streamId, new Queue[String]))
@@ -40,6 +71,12 @@ extends Logging {
           tmp ++= blockIds
         }
       }
+      case DeregisterReceiver(streamId, msg) => {
+        receiverInfo -= streamId
+        logInfo("De-registered receiver for network stream " + streamId
+          + " with message " + msg)
+        //TODO: Do something about the corresponding NetworkInputDStream
+      }
     }
   }
 
@@ -58,15 +95,15 @@ extends Logging {
   }
 
   def startReceivers() {
-    val tempRDD = ssc.sc.makeRDD(networkInputStreams, networkInputStreams.size)
-
-    val startReceiver = (iterator: Iterator[NetworkInputDStream[_]]) => {
+    val receivers = networkInputStreams.map(_.createReceiver())
+    val tempRDD = ssc.sc.makeRDD(receivers, receivers.size)
+
+    val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => {
       if (!iterator.hasNext) {
         throw new Exception("Could not start receiver as details not found.")
       }
-      iterator.next().runReceiver()
+      iterator.next().start()
     }
-
     ssc.sc.runJob(tempRDD, startReceiver)
   }
 
@@ -77,33 +114,4 @@ extends Logging {
       Await.result(futureOfList, timeout)
     }
   }
-
-  val networkInputStreamIds = networkInputStreams.map(_.id).toArray
-  val receiverExecutor = new ReceiverExecutor()
-  val receiverInfo = new HashMap[Int, ActorRef]
-  val receivedBlockIds = new HashMap[Int, Queue[String]]
-  val timeout = 5000.milliseconds
-
-
-  var currentTime: Time = null
-
-  def start() {
-    ssc.env.actorSystem.actorOf(Props(new TrackerActor), "NetworkInputTracker")
-    receiverExecutor.start()
-  }
-
-  def stop() {
-    // stop the actor
-    receiverExecutor.interrupt()
-  }
-
-  def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized {
-    val queue = receivedBlockIds.synchronized {
-      receivedBlockIds.getOrElse(receiverId, new Queue[String]())
-    }
-    val result = queue.synchronized {
-      queue.dequeueAll(x => true)
-    }
-    result.toArray
-  }
 }
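The tracker's getBlockIds() drains a receiver's queue on every call, so each block ID is handed to exactly one BlockRDD per batch. A standalone sketch of that dequeueAll draining behavior, with DrainDemo and its contents made up and the synchronization omitted:

import scala.collection.mutable.{HashMap, Queue}

object DrainDemo extends App {
  val receivedBlockIds = new HashMap[Int, Queue[String]]
  receivedBlockIds(0) = Queue("input-0-0", "input-0-1")

  // Same shape as NetworkInputTracker.getBlockIds, minus the locking
  def getBlockIds(receiverId: Int): Array[String] = {
    val queue = receivedBlockIds.getOrElse(receiverId, new Queue[String]())
    queue.dequeueAll(x => true).toArray
  }

  println(getBlockIds(0).mkString(","))  // input-0-0,input-0-1
  println(getBlockIds(0).length)         // 0 -- the queue was drained
}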
diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala
index 2396b374a0..89aeeda8b3 100644
--- a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala
@@ -1,16 +1,167 @@
 package spark.streaming
 
-import java.io.InputStream
+import spark.streaming.util.{RecurringTimer, SystemClock}
+import spark.storage.StorageLevel
+
+import java.io.{EOFException, DataInputStream, BufferedInputStream, InputStream}
+import java.net.Socket
+import java.util.concurrent.ArrayBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
 
 class ObjectInputDStream[T: ClassManifest](
-    @transient ssc: StreamingContext,
-    val host: String,
-    val port: Int,
-    val bytesToObjects: InputStream => Iterator[T])
-  extends NetworkInputDStream[T](ssc) {
-
-  override def runReceiver() {
-    new ObjectInputReceiver(id, host, port, bytesToObjects).run()
+    @transient ssc_ : StreamingContext,
+    host: String,
+    port: Int,
+    bytesToObjects: InputStream => Iterator[T],
+    storageLevel: StorageLevel
+  ) extends NetworkInputDStream[T](ssc_) {
+
+  def createReceiver(): NetworkReceiver[T] = {
+    new ObjectInputReceiver(id, host, port, bytesToObjects, storageLevel)
   }
 }
+
+
+class ObjectInputReceiver[T: ClassManifest](
+    streamId: Int,
+    host: String,
+    port: Int,
+    bytesToObjects: InputStream => Iterator[T],
+    storageLevel: StorageLevel
+  ) extends NetworkReceiver[T](streamId) {
+
+  lazy protected val dataHandler = new DataHandler(this)
+
+  protected def onStart() {
+    logInfo("Connecting to " + host + ":" + port)
+    val socket = new Socket(host, port)
+    logInfo("Connected to " + host + ":" + port)
+    dataHandler.start()
+    val iterator = bytesToObjects(socket.getInputStream())
+    while(iterator.hasNext) {
+      val obj = iterator.next
+      dataHandler += obj
+    }
+  }
+
+  protected def onStop() {
+    dataHandler.stop()
+  }
+
+  /**
+   * This is a helper object that manages the data received from the socket. It divides
+   * the objects received into small batches of 100s of milliseconds, pushes them as
+   * blocks into the block manager and reports the block IDs to the network input
+   * tracker. It starts two threads, one to periodically start a new batch and prepare
+   * the previous batch as a block, the other to push the blocks into the block
+   * manager.
+   */
+  class DataHandler(receiver: NetworkReceiver[T]) extends Serializable {
+    case class Block(id: String, iterator: Iterator[T])
+
+    val clock = new SystemClock()
+    val blockInterval = 200L
+    val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+    val blockStorageLevel = storageLevel
+    val blocksForPushing = new ArrayBlockingQueue[Block](1000)
+    val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+    var currentBuffer = new ArrayBuffer[T]
+
+    def start() {
+      blockIntervalTimer.start()
+      blockPushingThread.start()
+      logInfo("Data handler started")
+    }
+
+    def stop() {
+      blockIntervalTimer.stop()
+      blockPushingThread.interrupt()
+      logInfo("Data handler stopped")
+    }
+
+    def += (obj: T) {
+      currentBuffer += obj
+    }
+
+    def updateCurrentBuffer(time: Long) {
+      try {
+        val newBlockBuffer = currentBuffer
+        currentBuffer = new ArrayBuffer[T]
+        if (newBlockBuffer.size > 0) {
+          val blockId = "input-" + streamId + "-" + (time - blockInterval)
+          val newBlock = new Block(blockId, newBlockBuffer.toIterator)
+          blocksForPushing.add(newBlock)
+        }
+      } catch {
+        case ie: InterruptedException =>
+          logInfo("Block interval timer thread interrupted")
+        case e: Exception =>
+          receiver.stop()
+      }
+    }
+
+    def keepPushingBlocks() {
+      logInfo("Block pushing thread started")
+      try {
+        while(true) {
+          val block = blocksForPushing.take()
+          pushBlock(block.id, block.iterator, storageLevel)
+        }
+      } catch {
+        case ie: InterruptedException =>
+          logInfo("Block pushing thread interrupted")
+        case e: Exception =>
+          receiver.stop()
+      }
+    }
+  }
+}
+
+
+object ObjectInputReceiver {
+  def bytesToLines(inputStream: InputStream): Iterator[String] = {
+    val bufferedInputStream = new BufferedInputStream(inputStream)
+    val dataInputStream = new DataInputStream(bufferedInputStream)
+
+    val iterator = new Iterator[String] {
+      var gotNext = false
+      var finished = false
+      var nextValue: String = null
+
+      private def getNext() {
+        try {
+          nextValue = dataInputStream.readLine()
+          println("[" + nextValue + "]")
+        } catch {
+          case eof: EOFException =>
+            finished = true
+        }
+        gotNext = true
+      }
+
+      override def hasNext: Boolean = {
+        if (!gotNext) {
+          getNext()
+        }
+        if (finished) {
+          dataInputStream.close()
+        }
+        !finished
+      }
+
+      override def next(): String = {
+        if (!gotNext) {
+          getNext()
+        }
+        if (finished) {
+          throw new NoSuchElementException("End of stream")
+        }
+        gotNext = false
+        nextValue
+      }
+    }
+    iterator
+  }
+}
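DataHandler's batching can be hard to see through the Spark plumbing: one thread rolls currentBuffer into a block every blockInterval milliseconds, while a second thread drains the bounded queue and pushes blocks. The sketch below reproduces that two-thread handoff with a plain JDK thread standing in for RecurringTimer and println standing in for pushBlock; all names here are illustrative:

import java.util.concurrent.ArrayBlockingQueue
import scala.collection.mutable.ArrayBuffer

class BatchingSketch[T] {
  case class Block(id: String, items: Seq[T])

  private val blockInterval = 200L
  private val blocksForPushing = new ArrayBlockingQueue[Block](1000)
  private var currentBuffer = new ArrayBuffer[T]
  private var nextBlockNumber = 0

  def +=(obj: T) { synchronized { currentBuffer += obj } }

  // Stands in for the RecurringTimer callback: swap out the buffer and
  // enqueue it as a block if it is non-empty.
  private val timer = new Thread() {
    override def run() {
      while (true) {
        Thread.sleep(blockInterval)
        val batch = BatchingSketch.this.synchronized {
          val b = currentBuffer; currentBuffer = new ArrayBuffer[T]; b
        }
        if (batch.size > 0) {
          blocksForPushing.add(Block("input-0-" + nextBlockNumber, batch))
          nextBlockNumber += 1
        }
      }
    }
  }

  // Stands in for keepPushingBlocks(): take() blocks until a block arrives.
  private val pusher = new Thread() {
    override def run() {
      while (true) {
        val block = blocksForPushing.take()
        println("pushing " + block.id + " (" + block.items.size + " items)")
      }
    }
  }

  def start() { timer.start(); pusher.start() }
}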
diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala
deleted file mode 100644
index 70fa2cdf07..0000000000
--- a/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala
+++ /dev/null
@@ -1,244 +0,0 @@
-package spark.streaming
-
-import spark.Logging
-import spark.storage.BlockManager
-import spark.storage.StorageLevel
-import spark.SparkEnv
-import spark.streaming.util.SystemClock
-import spark.streaming.util.RecurringTimer
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Queue
-import scala.collection.mutable.SynchronizedPriorityQueue
-import scala.math.Ordering
-
-import java.net.InetSocketAddress
-import java.net.Socket
-import java.io.InputStream
-import java.io.BufferedInputStream
-import java.io.DataInputStream
-import java.io.EOFException
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.ArrayBlockingQueue
-
-import akka.actor._
-import akka.pattern.ask
-import akka.util.duration._
-import akka.dispatch._
-
-class ObjectInputReceiver[T: ClassManifest](
-    streamId: Int,
-    host: String,
-    port: Int,
-    bytesToObjects: InputStream => Iterator[T])
-  extends Logging {
-
-  class ReceiverActor extends Actor {
-    override def preStart() {
-      logInfo("Attempting to register")
-      val ip = System.getProperty("spark.master.host", "localhost")
-      val port = System.getProperty("spark.master.port", "7077").toInt
-      val actorName: String = "NetworkInputTracker"
-      val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
-      val trackerActor = env.actorSystem.actorFor(url)
-      val timeout = 1.seconds
-      val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout)
-      Await.result(future, timeout)
-    }
-
-    def receive = {
-      case GetBlockIds(time) => {
-        logInfo("Got request for block ids for " + time)
-        sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks())
-      }
-
-      case StopReceiver => {
-        if (receivingThread != null) {
-          receivingThread.interrupt()
-        }
-        sender ! true
-      }
-    }
-  }
-
-  class DataHandler {
-    class Block(val time: Long, val iterator: Iterator[T]) {
-      val blockId = "input-" + streamId + "-" + time
-      var pushed = true
-      override def toString = "input block " + blockId
-    }
-
-    val clock = new SystemClock()
-    val blockInterval = 200L
-    val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
-    val blockOrdering = new Ordering[Block] {
-      def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt
-    }
-    val blockStorageLevel = StorageLevel.DISK_AND_MEMORY
-    val blocksForPushing = new ArrayBlockingQueue[Block](1000)
-    val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering)
-    val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
-
-    var currentBuffer = new ArrayBuffer[T]
-
-    def start() {
-      blockIntervalTimer.start()
-      blockPushingThread.start()
-      logInfo("Data handler started")
-    }
-
-    def stop() {
-      blockIntervalTimer.stop()
-      blockPushingThread.interrupt()
-    }
-
-    def += (obj: T) {
-      currentBuffer += obj
-    }
-
-    def updateCurrentBuffer(time: Long) {
-      val newBlockBuffer = currentBuffer
-      currentBuffer = new ArrayBuffer[T]
-      if (newBlockBuffer.size > 0) {
-        val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator)
-        blocksForPushing.add(newBlock)
-        blocksForReporting.enqueue(newBlock)
-      }
-    }
-
-    def keepPushingBlocks() {
-      logInfo("Block pushing thread started")
-      try {
-        while(true) {
-          val block = blocksForPushing.take()
-          if (blockManager != null) {
-            blockManager.put(block.blockId, block.iterator, blockStorageLevel)
-            block.pushed = true
-          } else {
-            logWarning(block + " not put as block manager is null")
-          }
-        }
-      } catch {
-        case ie: InterruptedException => println("Block pushing thread interrupted")
-        case e: Exception => e.printStackTrace()
-      }
-    }
-
-    def getPushedBlocks(): Array[String] = {
-      val pushedBlocks = new ArrayBuffer[String]()
-      var loop = true
-      while(loop && !blocksForReporting.isEmpty) {
-        val block = blocksForReporting.dequeue()
-        if (block == null) {
-          loop = false
-        } else if (!block.pushed) {
-          blocksForReporting.enqueue(block)
-        } else {
-          pushedBlocks += block.blockId
-        }
-      }
-      logInfo("Got " + pushedBlocks.size + " blocks")
-      pushedBlocks.toArray
-    }
-  }
-
-  val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null
-  val dataHandler = new DataHandler()
-  val env = SparkEnv.get
-
-  var receiverActor: ActorRef = null
-  var receivingThread: Thread = null
-
-  def run() {
-    initLogging()
-    var socket: Socket = null
-    try {
-      if (SparkEnv.get != null) {
-        receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId)
-      }
-      dataHandler.start()
-      socket = connect()
-      receivingThread = Thread.currentThread()
-      receive(socket)
-    } catch {
-      case ie: InterruptedException => logInfo("Receiver interrupted")
-    } finally {
-      receivingThread = null
-      if (socket != null) socket.close()
-      dataHandler.stop()
-    }
-  }
-
-  def connect(): Socket = {
-    logInfo("Connecting to " + host + ":" + port)
-    val socket = new Socket(host, port)
-    logInfo("Connected to " + host + ":" + port)
-    socket
-  }
-
-  def receive(socket: Socket) {
-    val iterator = bytesToObjects(socket.getInputStream())
-    while(iterator.hasNext) {
-      val obj = iterator.next
-      dataHandler += obj
-    }
-  }
-}
-
-
-object ObjectInputReceiver {
-  def bytesToLines(inputStream: InputStream): Iterator[String] = {
-    val bufferedInputStream = new BufferedInputStream(inputStream)
-    val dataInputStream = new DataInputStream(bufferedInputStream)
-
-    val iterator = new Iterator[String] {
-      var gotNext = false
-      var finished = false
-      var nextValue: String = null
-
-      private def getNext() {
-        try {
-          nextValue = dataInputStream.readLine()
-          println("[" + nextValue + "]")
-        } catch {
-          case eof: EOFException =>
-            finished = true
-        }
-        gotNext = true
-      }
-
-      override def hasNext: Boolean = {
-        if (!gotNext) {
-          getNext()
-        }
-        if (finished) {
-          dataInputStream.close()
-        }
-        !finished
-      }
-
-      override def next(): String = {
-        if (!gotNext) {
-          getNext()
-        }
-        if (finished) {
-          throw new NoSuchElementException("End of stream")
-        }
-        gotNext = false
-        nextValue
-      }
-    }
-    iterator
-  }
-
-  def main(args: Array[String]) {
-    if (args.length < 2) {
-      println("ObjectInputReceiver <host> <port>")
-      System.exit(1)
-    }
-    val host = args(0)
-    val port = args(1).toInt
-    val receiver = new ObjectInputReceiver(0, host, port, bytesToLines)
-    receiver.run()
-  }
-}
diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
index d29aea7886..e022b85fbe 100644
--- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
@@ -1,16 +1,11 @@
 package spark.streaming
 
-import akka.actor._
-import akka.pattern.ask
-import akka.util.duration._
-import akka.dispatch._
 import java.net.InetSocketAddress
 import java.nio.ByteBuffer
 import java.nio.channels.{ReadableByteChannel, SocketChannel}
 import java.io.EOFException
 import java.util.concurrent.ArrayBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import spark.{DaemonThread, Logging, SparkEnv}
+import spark._
 import spark.storage.StorageLevel
 
 /**
@@ -20,20 +15,23 @@ import spark.storage.StorageLevel
  * in the format that the system is configured with.
  */
 class RawInputDStream[T: ClassManifest](
-    @transient ssc: StreamingContext,
+    @transient ssc_ : StreamingContext,
     host: String,
     port: Int,
-    storageLevel: StorageLevel)
-  extends NetworkInputDStream[T](ssc) with Logging {
+    storageLevel: StorageLevel
+  ) extends NetworkInputDStream[T](ssc_) with Logging {
 
-  val streamId = id
+  def createReceiver(): NetworkReceiver[T] = {
+    new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]]
+  }
+}
+
+class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel)
+  extends NetworkReceiver[Any](streamId) {
 
-  /** Called on workers to run a receiver for the stream. */
-  def runReceiver() {
-    val env = SparkEnv.get
-    val actor = env.actorSystem.actorOf(
-      Props(new ReceiverActor(env, Thread.currentThread)), "ReceiverActor-" + streamId)
+  var blockPushingThread: Thread = null
 
+  def onStart() {
     // Open a socket to the target address and keep reading from it
     logInfo("Connecting to " + host + ":" + port)
     val channel = SocketChannel.open()
@@ -43,18 +41,18 @@ class RawInputDStream[T: ClassManifest](
 
     val queue = new ArrayBlockingQueue[ByteBuffer](2)
 
-    new DaemonThread {
+    blockPushingThread = new DaemonThread {
       override def run() {
         var nextBlockNumber = 0
         while (true) {
           val buffer = queue.take()
           val blockId = "input-" + streamId + "-" + nextBlockNumber
           nextBlockNumber += 1
-          env.blockManager.putBytes(blockId, buffer, storageLevel)
-          actor ! BlockPublished(blockId)
+          pushBlock(blockId, buffer, storageLevel)
         }
       }
-    }.start()
+    }
+    blockPushingThread.start()
 
     val lengthBuffer = ByteBuffer.allocate(4)
     while (true) {
@@ -70,6 +68,10 @@ class RawInputDStream[T: ClassManifest](
     }
   }
 
+  def onStop() {
+    blockPushingThread.interrupt()
+  }
+
   /** Read a buffer fully from a given Channel */
   private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) {
     while (dest.position < dest.limit) {
@@ -78,41 +80,4 @@ class RawInputDStream[T: ClassManifest](
       }
     }
   }
-
-  /** Message sent to ReceiverActor to tell it that a block was published */
-  case class BlockPublished(blockId: String) {}
-
-  /** A helper actor that communicates with the NetworkInputTracker */
-  private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor {
-    val newBlocks = new ArrayBuffer[String]
-
-    logInfo("Attempting to register with tracker")
-    val ip = System.getProperty("spark.master.host", "localhost")
-    val port = System.getProperty("spark.master.port", "7077").toInt
-    val actorName: String = "NetworkInputTracker"
-    val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
-    val trackerActor = env.actorSystem.actorFor(url)
-    val timeout = 5.seconds
-
-    override def preStart() {
-      val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout)
-      Await.result(future, timeout)
-    }
-
-    override def receive = {
-      case BlockPublished(blockId) =>
-        newBlocks += blockId
-        val future = trackerActor ! GotBlockIds(streamId, Array(blockId))
-
-      case GetBlockIds(time) =>
-        logInfo("Got request for block IDs for " + time)
-        sender ! GotBlockIds(streamId, newBlocks.toArray)
-        newBlocks.clear()
-
-      case StopReceiver =>
-        receivingThread.interrupt()
-        sender ! true
-    }
-
-  }
 }
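RawNetworkReceiver expects length-prefixed blocks on the wire: a 4-byte big-endian size, then exactly that many payload bytes, which it hands to the block manager without deserializing. A test sender matching that framing might look like the following sketch, where the host, port, and payload are placeholders:

import java.io.DataOutputStream
import java.net.Socket

object RawSenderSketch extends App {
  val socket = new Socket("localhost", 9999)
  val out = new DataOutputStream(socket.getOutputStream)
  val payload = "hello raw stream".getBytes("UTF-8")
  out.writeInt(payload.length)  // the receiver reads this into lengthBuffer
  out.write(payload)            // then readFully()s exactly this many bytes
  out.flush()
  socket.close()
}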
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 7022056f7c..1dc5614a5c 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -95,7 +95,7 @@ class StreamingContext (
       port: Int,
       converter: (InputStream) => Iterator[T]
     ): DStream[T] = {
-    val inputStream = new ObjectInputDStream[T](this, hostname, port, converter)
+    val inputStream = new ObjectInputDStream[T](this, hostname, port, converter, StorageLevel.DISK_AND_MEMORY_2)
     graph.addInputStream(inputStream)
     inputStream
   }
@@ -207,7 +207,7 @@ class StreamingContext (
   }
 
   /**
-   * This function starts the execution of the streams.
+   * This function stops the execution of the streams.
    */
   def stop() {
     try {
-- 
cgit v1.2.3