diff options
author | Denny <dennybritz@gmail.com> | 2012-11-09 12:23:46 -0800 |
---|---|---|
committer | Denny <dennybritz@gmail.com> | 2012-11-09 12:23:46 -0800 |
commit | e5a09367870be757a0abb3e2ad7a53e74110b033 (patch) | |
tree | ef5b73884102b1bdfd35aacbca9f7b73924b2f2c /streaming/src | |
parent | 485803d740307e03beee056390b0ecb0a76fbbb1 (diff) | |
download | spark-e5a09367870be757a0abb3e2ad7a53e74110b033.tar.gz spark-e5a09367870be757a0abb3e2ad7a53e74110b033.tar.bz2 spark-e5a09367870be757a0abb3e2ad7a53e74110b033.zip |
Kafka Stream.
Diffstat (limited to 'streaming/src')
9 files changed, 245 insertions, 161 deletions
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 922ff5088d..f891730317 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -17,6 +17,8 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration +case class DStreamCheckpointData(rdds: HashMap[Time, Any]) + abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -59,7 +61,7 @@ extends Serializable with Logging { // Checkpoint details protected[streaming] val mustCheckpoint = false protected[streaming] var checkpointInterval: Time = null - protected[streaming] val checkpointData = new HashMap[Time, Any]() + protected[streaming] var checkpointData = DStreamCheckpointData(HashMap[Time, Any]()) // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null @@ -280,6 +282,13 @@ extends Serializable with Logging { dependencies.foreach(_.forgetOldRDDs(time)) } + /* Adds metadata to the Stream while it is running. + * This methd should be overwritten by sublcasses of InputDStream. + */ + protected[streaming] def addMetadata(metadata: Any) { + logInfo("Dropping Metadata: " + metadata.toString) + } + /** * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of * this stream. This is an internal method that should not be called directly. This is @@ -288,22 +297,22 @@ extends Serializable with Logging { * this method to save custom checkpoint data. */ protected[streaming] def updateCheckpointData(currentTime: Time) { - val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) + val newRdds = generatedRDDs.filter(_._2.getCheckpointData() != null) .map(x => (x._1, x._2.getCheckpointData())) - val oldCheckpointData = checkpointData.clone() - if (newCheckpointData.size > 0) { - checkpointData.clear() - checkpointData ++= newCheckpointData + val oldRdds = checkpointData.rdds.clone() + if (newRdds.size > 0) { + checkpointData.rdds.clear() + checkpointData.rdds ++= newRdds } dependencies.foreach(_.updateCheckpointData(currentTime)) - newCheckpointData.foreach { + newRdds.foreach { case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } } - if (newCheckpointData.size > 0) { - (oldCheckpointData -- newCheckpointData.keySet).foreach { + if (newRdds.size > 0) { + (oldRdds -- newRdds.keySet).foreach { case (time, data) => { val path = new Path(data.toString) val fs = path.getFileSystem(new Configuration()) @@ -322,8 +331,8 @@ extends Serializable with Logging { * override the updateCheckpointData() method would also need to override this method. */ protected[streaming] def restoreCheckpointData() { - logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") - checkpointData.foreach { + logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs") + checkpointData.rdds.foreach { case(time, data) => { logInfo("Restoring checkpointed RDD for time " + time + " from file") generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala new file mode 100644 index 0000000000..05f307a8d1 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DataHandler.scala @@ -0,0 +1,83 @@ +package spark.streaming + +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.Logging +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + + +/** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + }
\ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index f3f4c3ab13..d3f37b8b0e 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -4,9 +4,11 @@ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkEnv, RDD} import spark.rdd.BlockRDD +import spark.streaming.util.{RecurringTimer, SystemClock} import spark.storage.StorageLevel import java.nio.ByteBuffer +import java.util.concurrent.ArrayBlockingQueue import akka.actor.{Props, Actor} import akka.pattern.ask @@ -41,10 +43,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming sealed trait NetworkReceiverMessage case class StopReceiver(msg: String) extends NetworkReceiverMessage -case class ReportBlock(blockId: String) extends NetworkReceiverMessage +case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage case class ReportError(msg: String) extends NetworkReceiverMessage -abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging { +abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { initLogging() @@ -106,21 +108,23 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ actor ! ReportError(e.toString) } + /** * This method pushes a block (as iterator of values) into the block manager. */ - protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) { + def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { val buffer = new ArrayBuffer[T] ++ iterator env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) - actor ! ReportBlock(blockId) + + actor ! ReportBlock(blockId, metadata) } /** * This method pushes a block (as bytes) into the block manager. */ - protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId) + actor ! ReportBlock(blockId, metadata) } /** A helper actor that communicates with the NetworkInputTracker */ @@ -138,8 +142,8 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ } override def receive() = { - case ReportBlock(blockId) => - tracker ! AddBlocks(streamId, Array(blockId)) + case ReportBlock(blockId, metadata) => + tracker ! AddBlocks(streamId, Array(blockId), metadata) case ReportError(msg) => tracker ! DeregisterReceiver(streamId, msg) case StopReceiver(msg) => @@ -147,5 +151,6 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ tracker ! DeregisterReceiver(streamId, msg) } } + } diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 07ef79415d..4d9346edd8 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -13,7 +13,7 @@ import akka.dispatch._ trait NetworkInputTrackerMessage case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage +case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage @@ -22,7 +22,7 @@ class NetworkInputTracker( @transient networkInputStreams: Array[NetworkInputDStream[_]]) extends Logging { - val networkInputStreamIds = networkInputStreams.map(_.id).toArray + val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] val receivedBlockIds = new HashMap[Int, Queue[String]] @@ -53,14 +53,14 @@ class NetworkInputTracker( private class NetworkInputTrackerActor extends Actor { def receive = { case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamIds.contains(streamId)) { + if (!networkInputStreamMap.contains(streamId)) { throw new Exception("Register received for unexpected id " + streamId) } receiverInfo += ((streamId, receiverActor)) logInfo("Registered receiver for network stream " + streamId) sender ! true } - case AddBlocks(streamId, blockIds) => { + case AddBlocks(streamId, blockIds, metadata) => { val tmp = receivedBlockIds.synchronized { if (!receivedBlockIds.contains(streamId)) { receivedBlockIds += ((streamId, new Queue[String])) @@ -70,6 +70,7 @@ class NetworkInputTracker( tmp.synchronized { tmp ++= blockIds } + networkInputStreamMap(streamId).addMetadata(metadata) } case DeregisterReceiver(streamId, msg) => { receiverInfo -= streamId diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index e022b85fbe..90d8528d5b 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -48,7 +48,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S val buffer = queue.take() val blockId = "input-" + streamId + "-" + nextBlockNumber nextBlockNumber += 1 - pushBlock(blockId, buffer, storageLevel) + pushBlock(blockId, buffer, null, storageLevel) } } } diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index b566200273..ff99d50b76 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -32,7 +32,7 @@ class SocketReceiver[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkReceiver[T](streamId) { - lazy protected val dataHandler = new DataHandler(this) + lazy protected val dataHandler = new DataHandler(this, storageLevel) protected def onStart() { logInfo("Connecting to " + host + ":" + port) @@ -50,74 +50,6 @@ class SocketReceiver[T: ClassManifest]( dataHandler.stop() } - /** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager. - */ - class DataHandler(receiver: NetworkReceiver[T]) extends Serializable { - case class Block(id: String, iterator: Iterator[T]) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + streamId + "- " + (time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - pushBlock(block.id, block.iterator, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 05c83d6c08..770fd61498 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -106,9 +106,11 @@ class StreamingContext ( hostname: String, port: Int, groupId: String, + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, storageLevel) + val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) graph.addInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 3f637150d1..655f9627b3 100644 --- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -1,6 +1,6 @@ package spark.streaming.examples -import spark.streaming.{Seconds, StreamingContext, KafkaInputDStream} +import spark.streaming._ import spark.streaming.StreamingContext._ import spark.storage.StorageLevel @@ -17,11 +17,20 @@ object KafkaWordCount { // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group") + ssc.checkpoint("checkpoint", Time(1000 * 5)) + val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group", Map("test" -> 1), + Map(KafkaPartitionKey(0, "test", "test_group", 0) -> 2382)) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() ssc.start() + // Wait for 12 seconds + Thread.sleep(12000) + ssc.stop() + + val newSsc = new StreamingContext("checkpoint") + newSsc.start() + } } diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index 427f398237..814f2706d6 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -1,121 +1,164 @@ package spark.streaming +import java.lang.reflect.Method import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ArrayBlockingQueue, Executors} +import java.util.concurrent.{ArrayBlockingQueue, ConcurrentHashMap, Executors} import kafka.api.{FetchRequest} -import kafka.consumer.{Consumer, ConsumerConfig, KafkaStream} -import kafka.javaapi.consumer.SimpleConsumer -import kafka.javaapi.message.ByteBufferMessageSet +import kafka.consumer._ +import kafka.cluster.Partition import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.utils.Utils +import kafka.serializer.StringDecoder +import kafka.utils.{Pool, Utils, ZKGroupTopicDirs} +import kafka.utils.ZkUtils._ +import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ import spark._ import spark.RDD import spark.storage.StorageLevel +case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], + savedOffsets: HashMap[Long, Map[KafkaPartitionKey, Long]]) extends DStreamCheckpointData(kafkaRdds) + /** - * An input stream that pulls messages form a Kafka Broker. + * Input stream that pulls messages form a Kafka Broker. */ class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, host: String, port: Int, groupId: String, - storageLevel: StorageLevel, - timeout: Int = 10000, - bufferSize: Int = 1024000 + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { + var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + + override protected[streaming] def addMetadata(metadata: Any) { + metadata match { + case x : KafkaInputDStreamMetadata => + savedOffsets(x.timestamp) = x.data + logInfo("Saved Offsets: " + savedOffsets) + case _ => logInfo("Received unknown metadata: " + metadata.toString) + } + } + + override protected[streaming] def updateCheckpointData(currentTime: Time) { + super.updateCheckpointData(currentTime) + logInfo("Updating KafkaDStream checkpoint data: " + savedOffsets.toString) + checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, savedOffsets) + } + + override protected[streaming] def restoreCheckpointData() { + super.restoreCheckpointData() + logInfo("Restoring KafkaDStream checkpoint data.") + checkpointData match { + case x : KafkaDStreamCheckpointData => + savedOffsets = x.savedOffsets + logInfo("Restored KafkaDStream offsets: " + savedOffsets.toString) + } + } + def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(id, host, port, storageLevel, groupId, timeout).asInstanceOf[NetworkReceiver[T]] + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] } } -class KafkaReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel, groupId: String, timeout: Int) - extends NetworkReceiver[Any](streamId) { +class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, + topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { + + // Timeout for establishing a connection to Zookeper in ms. + val ZK_TIMEOUT = 10000 - //var executorPool : = null - var blockPushingThread : Thread = null + // Handles pushing data into the BlockManager + lazy protected val dataHandler = new KafkaDataHandler(this, storageLevel) + // Keeps track of the current offsets. Maps from (topic, partitionID) -> Offset + lazy val offsets = HashMap[KafkaPartitionKey, Long]() + // Connection to Kafka + var consumerConnector : ZookeeperConsumerConnector = null def onStop() { - blockPushingThread.interrupt() + dataHandler.stop() } def onStart() { - val executorPool = Executors.newFixedThreadPool(2) + // Starting the DataHandler that buffers blocks and pushes them into them BlockManager + dataHandler.start() - logInfo("Starting Kafka Consumer with groupId " + groupId) + // In case we are using multiple Threads to handle Kafka Messages + val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) val zooKeeperEndPoint = host + ":" + port + logInfo("Starting Kafka Consumer Stream in group " + groupId) + logInfo("Initial offsets: " + initialOffsets.toString) logInfo("Connecting to " + zooKeeperEndPoint) - - // Specify some consumer properties + // Specify some Consumer properties val props = new Properties() props.put("zk.connect", zooKeeperEndPoint) - props.put("zk.connectiontimeout.ms", timeout.toString) + props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) props.put("groupid", groupId) // Create the connection to the cluster val consumerConfig = new ConsumerConfig(props) - val consumerConnector = Consumer.create(consumerConfig) - logInfo("Connected to " + zooKeeperEndPoint) - logInfo("") - logInfo("") + consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - // Specify which topics we are listening to - val topicCountMap = Map("test" -> 2) - val topicMessageStreams = consumerConnector.createMessageStreams(topicCountMap) - val streams = topicMessageStreams.get("test") + // Reset the Kafka offsets in case we are recovering from a failure + resetOffsets(initialOffsets) - // Queue that holds the blocks - val queue = new ArrayBlockingQueue[ByteBuffer](2) + logInfo("Connected to " + zooKeeperEndPoint) - streams.getOrElse(Nil).foreach { stream => - executorPool.submit(new MessageHandler(stream, queue)) + // Create Threads for each Topic/Message Stream we are listening + val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } } - blockPushingThread = new DaemonThread { - override def run() { - logInfo("Starting BlockPushingThread.") - var nextBlockNumber = 0 - while (true) { - val buffer = queue.take() - val blockId = "input-" + streamId + "-" + nextBlockNumber - nextBlockNumber += 1 - pushBlock(blockId, buffer, storageLevel) - } - } + } + + // Overwrites the offets in Zookeper. + private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + offsets.foreach { case(key, offset) => + val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) + val partitionName = key.brokerId + "-" + key.partId + updatePersistentPath(consumerConnector.zkClient, + topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) } - blockPushingThread.start() - - // while (true) { - // // Create a fetch request for topic “test”, partition 0, current offset, and fetch size of 1MB - // val fetchRequest = new FetchRequest("test", 0, offset, 1000000) - - // // get the message set from the consumer and print them out - // val messages = consumer.fetch(fetchRequest) - // for(msg <- messages.iterator) { - // logInfo("consumed: " + Utils.toString(msg.message.payload, "UTF-8")) - // // advance the offset after consuming each message - // offset = msg.offset - // queue.put(msg.message.payload) - // } - // } } - class MessageHandler(stream: KafkaStream[Message], queue: ArrayBlockingQueue[ByteBuffer]) extends Runnable { + // Responsible for handling Kafka Messages + class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { logInfo("Starting MessageHandler.") - while(true) { - stream.foreach { msgAndMetadata => - logInfo("Consumed: " + Utils.toString(msgAndMetadata.message.payload, "UTF-8")) - queue.put(msgAndMetadata.message.payload) - } - } + stream.takeWhile { msgAndMetadata => + dataHandler += msgAndMetadata.message + + // Updating the offet. The key is (topic, partitionID). + val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, + groupId, msgAndMetadata.topicInfo.partition.partId) + val offset = msgAndMetadata.topicInfo.getConsumeOffset + offsets.put(key, offset) + logInfo((key, offset).toString) + + // Keep on handling messages + true + } } } + class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) + extends DataHandler[Any](receiver, storageLevel) { + + override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) + } + + } } |