aboutsummaryrefslogtreecommitdiff
path: root/streaming/src
diff options
context:
space:
mode:
authorDenny <dennybritz@gmail.com>2012-11-09 12:23:46 -0800
committerDenny <dennybritz@gmail.com>2012-11-09 12:23:46 -0800
commite5a09367870be757a0abb3e2ad7a53e74110b033 (patch)
treeef5b73884102b1bdfd35aacbca9f7b73924b2f2c /streaming/src
parent485803d740307e03beee056390b0ecb0a76fbbb1 (diff)
downloadspark-e5a09367870be757a0abb3e2ad7a53e74110b033.tar.gz
spark-e5a09367870be757a0abb3e2ad7a53e74110b033.tar.bz2
spark-e5a09367870be757a0abb3e2ad7a53e74110b033.zip
Kafka Stream.
Diffstat (limited to 'streaming/src')
-rw-r--r--streaming/src/main/scala/spark/streaming/DStream.scala31
-rw-r--r--streaming/src/main/scala/spark/streaming/DataHandler.scala83
-rw-r--r--streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala21
-rw-r--r--streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala9
-rw-r--r--streaming/src/main/scala/spark/streaming/RawInputDStream.scala2
-rw-r--r--streaming/src/main/scala/spark/streaming/SocketInputDStream.scala70
-rw-r--r--streaming/src/main/scala/spark/streaming/StreamingContext.scala4
-rw-r--r--streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala13
-rw-r--r--streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala173
9 files changed, 245 insertions, 161 deletions
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 922ff5088d..f891730317 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -17,6 +17,8 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
+case class DStreamCheckpointData(rdds: HashMap[Time, Any])
+
abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
extends Serializable with Logging {
@@ -59,7 +61,7 @@ extends Serializable with Logging {
// Checkpoint details
protected[streaming] val mustCheckpoint = false
protected[streaming] var checkpointInterval: Time = null
- protected[streaming] val checkpointData = new HashMap[Time, Any]()
+ protected[streaming] var checkpointData = DStreamCheckpointData(HashMap[Time, Any]())
// Reference to whole DStream graph
protected[streaming] var graph: DStreamGraph = null
@@ -280,6 +282,13 @@ extends Serializable with Logging {
dependencies.foreach(_.forgetOldRDDs(time))
}
+ /* Adds metadata to the Stream while it is running.
+ * This methd should be overwritten by sublcasses of InputDStream.
+ */
+ protected[streaming] def addMetadata(metadata: Any) {
+ logInfo("Dropping Metadata: " + metadata.toString)
+ }
+
/**
* Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of
* this stream. This is an internal method that should not be called directly. This is
@@ -288,22 +297,22 @@ extends Serializable with Logging {
* this method to save custom checkpoint data.
*/
protected[streaming] def updateCheckpointData(currentTime: Time) {
- val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null)
+ val newRdds = generatedRDDs.filter(_._2.getCheckpointData() != null)
.map(x => (x._1, x._2.getCheckpointData()))
- val oldCheckpointData = checkpointData.clone()
- if (newCheckpointData.size > 0) {
- checkpointData.clear()
- checkpointData ++= newCheckpointData
+ val oldRdds = checkpointData.rdds.clone()
+ if (newRdds.size > 0) {
+ checkpointData.rdds.clear()
+ checkpointData.rdds ++= newRdds
}
dependencies.foreach(_.updateCheckpointData(currentTime))
- newCheckpointData.foreach {
+ newRdds.foreach {
case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
}
- if (newCheckpointData.size > 0) {
- (oldCheckpointData -- newCheckpointData.keySet).foreach {
+ if (newRdds.size > 0) {
+ (oldRdds -- newRdds.keySet).foreach {
case (time, data) => {
val path = new Path(data.toString)
val fs = path.getFileSystem(new Configuration())
@@ -322,8 +331,8 @@ extends Serializable with Logging {
* override the updateCheckpointData() method would also need to override this method.
*/
protected[streaming] def restoreCheckpointData() {
- logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs")
- checkpointData.foreach {
+ logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs")
+ checkpointData.rdds.foreach {
case(time, data) => {
logInfo("Restoring checkpointed RDD for time " + time + " from file")
generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString)))
diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala
new file mode 100644
index 0000000000..05f307a8d1
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DataHandler.scala
@@ -0,0 +1,83 @@
+package spark.streaming
+
+import java.util.concurrent.ArrayBlockingQueue
+import scala.collection.mutable.ArrayBuffer
+import spark.Logging
+import spark.streaming.util.{RecurringTimer, SystemClock}
+import spark.storage.StorageLevel
+
+
+/**
+ * This is a helper object that manages the data received from the socket. It divides
+ * the object received into small batches of 100s of milliseconds, pushes them as
+ * blocks into the block manager and reports the block IDs to the network input
+ * tracker. It starts two threads, one to periodically start a new batch and prepare
+ * the previous batch of as a block, the other to push the blocks into the block
+ * manager.
+ */
+ class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel)
+ extends Serializable with Logging {
+
+ case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
+
+ val clock = new SystemClock()
+ val blockInterval = 200L
+ val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+ val blockStorageLevel = storageLevel
+ val blocksForPushing = new ArrayBlockingQueue[Block](1000)
+ val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+ var currentBuffer = new ArrayBuffer[T]
+
+ def createBlock(blockId: String, iterator: Iterator[T]) : Block = {
+ new Block(blockId, iterator)
+ }
+
+ def start() {
+ blockIntervalTimer.start()
+ blockPushingThread.start()
+ logInfo("Data handler started")
+ }
+
+ def stop() {
+ blockIntervalTimer.stop()
+ blockPushingThread.interrupt()
+ logInfo("Data handler stopped")
+ }
+
+ def += (obj: T) {
+ currentBuffer += obj
+ }
+
+ def updateCurrentBuffer(time: Long) {
+ try {
+ val newBlockBuffer = currentBuffer
+ currentBuffer = new ArrayBuffer[T]
+ if (newBlockBuffer.size > 0) {
+ val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval)
+ val newBlock = createBlock(blockId, newBlockBuffer.toIterator)
+ blocksForPushing.add(newBlock)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block interval timer thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+
+ def keepPushingBlocks() {
+ logInfo("Block pushing thread started")
+ try {
+ while(true) {
+ val block = blocksForPushing.take()
+ receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block pushing thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+ } \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
index f3f4c3ab13..d3f37b8b0e 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
@@ -4,9 +4,11 @@ import scala.collection.mutable.ArrayBuffer
import spark.{Logging, SparkEnv, RDD}
import spark.rdd.BlockRDD
+import spark.streaming.util.{RecurringTimer, SystemClock}
import spark.storage.StorageLevel
import java.nio.ByteBuffer
+import java.util.concurrent.ArrayBlockingQueue
import akka.actor.{Props, Actor}
import akka.pattern.ask
@@ -41,10 +43,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
sealed trait NetworkReceiverMessage
case class StopReceiver(msg: String) extends NetworkReceiverMessage
-case class ReportBlock(blockId: String) extends NetworkReceiverMessage
+case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
case class ReportError(msg: String) extends NetworkReceiverMessage
-abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging {
+abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging {
initLogging()
@@ -106,21 +108,23 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
actor ! ReportError(e.toString)
}
+
/**
* This method pushes a block (as iterator of values) into the block manager.
*/
- protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) {
+ def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) {
val buffer = new ArrayBuffer[T] ++ iterator
env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level)
- actor ! ReportBlock(blockId)
+
+ actor ! ReportBlock(blockId, metadata)
}
/**
* This method pushes a block (as bytes) into the block manager.
*/
- protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
- actor ! ReportBlock(blockId)
+ actor ! ReportBlock(blockId, metadata)
}
/** A helper actor that communicates with the NetworkInputTracker */
@@ -138,8 +142,8 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
}
override def receive() = {
- case ReportBlock(blockId) =>
- tracker ! AddBlocks(streamId, Array(blockId))
+ case ReportBlock(blockId, metadata) =>
+ tracker ! AddBlocks(streamId, Array(blockId), metadata)
case ReportError(msg) =>
tracker ! DeregisterReceiver(streamId, msg)
case StopReceiver(msg) =>
@@ -147,5 +151,6 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
tracker ! DeregisterReceiver(streamId, msg)
}
}
+
}
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
index 07ef79415d..4d9346edd8 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -13,7 +13,7 @@ import akka.dispatch._
trait NetworkInputTrackerMessage
case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
-case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage
+case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
@@ -22,7 +22,7 @@ class NetworkInputTracker(
@transient networkInputStreams: Array[NetworkInputDStream[_]])
extends Logging {
- val networkInputStreamIds = networkInputStreams.map(_.id).toArray
+ val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
val receivedBlockIds = new HashMap[Int, Queue[String]]
@@ -53,14 +53,14 @@ class NetworkInputTracker(
private class NetworkInputTrackerActor extends Actor {
def receive = {
case RegisterReceiver(streamId, receiverActor) => {
- if (!networkInputStreamIds.contains(streamId)) {
+ if (!networkInputStreamMap.contains(streamId)) {
throw new Exception("Register received for unexpected id " + streamId)
}
receiverInfo += ((streamId, receiverActor))
logInfo("Registered receiver for network stream " + streamId)
sender ! true
}
- case AddBlocks(streamId, blockIds) => {
+ case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) {
receivedBlockIds += ((streamId, new Queue[String]))
@@ -70,6 +70,7 @@ class NetworkInputTracker(
tmp.synchronized {
tmp ++= blockIds
}
+ networkInputStreamMap(streamId).addMetadata(metadata)
}
case DeregisterReceiver(streamId, msg) => {
receiverInfo -= streamId
diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
index e022b85fbe..90d8528d5b 100644
--- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
@@ -48,7 +48,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S
val buffer = queue.take()
val blockId = "input-" + streamId + "-" + nextBlockNumber
nextBlockNumber += 1
- pushBlock(blockId, buffer, storageLevel)
+ pushBlock(blockId, buffer, null, storageLevel)
}
}
}
diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
index b566200273..ff99d50b76 100644
--- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
@@ -32,7 +32,7 @@ class SocketReceiver[T: ClassManifest](
storageLevel: StorageLevel
) extends NetworkReceiver[T](streamId) {
- lazy protected val dataHandler = new DataHandler(this)
+ lazy protected val dataHandler = new DataHandler(this, storageLevel)
protected def onStart() {
logInfo("Connecting to " + host + ":" + port)
@@ -50,74 +50,6 @@ class SocketReceiver[T: ClassManifest](
dataHandler.stop()
}
- /**
- * This is a helper object that manages the data received from the socket. It divides
- * the object received into small batches of 100s of milliseconds, pushes them as
- * blocks into the block manager and reports the block IDs to the network input
- * tracker. It starts two threads, one to periodically start a new batch and prepare
- * the previous batch of as a block, the other to push the blocks into the block
- * manager.
- */
- class DataHandler(receiver: NetworkReceiver[T]) extends Serializable {
- case class Block(id: String, iterator: Iterator[T])
-
- val clock = new SystemClock()
- val blockInterval = 200L
- val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
- val blockStorageLevel = storageLevel
- val blocksForPushing = new ArrayBlockingQueue[Block](1000)
- val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
-
- var currentBuffer = new ArrayBuffer[T]
-
- def start() {
- blockIntervalTimer.start()
- blockPushingThread.start()
- logInfo("Data handler started")
- }
-
- def stop() {
- blockIntervalTimer.stop()
- blockPushingThread.interrupt()
- logInfo("Data handler stopped")
- }
-
- def += (obj: T) {
- currentBuffer += obj
- }
-
- def updateCurrentBuffer(time: Long) {
- try {
- val newBlockBuffer = currentBuffer
- currentBuffer = new ArrayBuffer[T]
- if (newBlockBuffer.size > 0) {
- val blockId = "input-" + streamId + "- " + (time - blockInterval)
- val newBlock = new Block(blockId, newBlockBuffer.toIterator)
- blocksForPushing.add(newBlock)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block interval timer thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
-
- def keepPushingBlocks() {
- logInfo("Block pushing thread started")
- try {
- while(true) {
- val block = blocksForPushing.take()
- pushBlock(block.id, block.iterator, storageLevel)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block pushing thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
- }
}
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 05c83d6c08..770fd61498 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -106,9 +106,11 @@ class StreamingContext (
hostname: String,
port: Int,
groupId: String,
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
): DStream[T] = {
- val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, storageLevel)
+ val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel)
graph.addInputStream(inputStream)
inputStream
}
diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
index 3f637150d1..655f9627b3 100644
--- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -1,6 +1,6 @@
package spark.streaming.examples
-import spark.streaming.{Seconds, StreamingContext, KafkaInputDStream}
+import spark.streaming._
import spark.streaming.StreamingContext._
import spark.storage.StorageLevel
@@ -17,11 +17,20 @@ object KafkaWordCount {
// Create a NetworkInputDStream on target ip:port and count the
// words in input stream of \n delimited test (eg. generated by 'nc')
- val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group")
+ ssc.checkpoint("checkpoint", Time(1000 * 5))
+ val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group", Map("test" -> 1),
+ Map(KafkaPartitionKey(0, "test", "test_group", 0) -> 2382))
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
wordCounts.print()
ssc.start()
+ // Wait for 12 seconds
+ Thread.sleep(12000)
+ ssc.stop()
+
+ val newSsc = new StreamingContext("checkpoint")
+ newSsc.start()
+
}
}
diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
index 427f398237..814f2706d6 100644
--- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
@@ -1,121 +1,164 @@
package spark.streaming
+import java.lang.reflect.Method
import java.nio.ByteBuffer
import java.util.Properties
-import java.util.concurrent.{ArrayBlockingQueue, Executors}
+import java.util.concurrent.{ArrayBlockingQueue, ConcurrentHashMap, Executors}
import kafka.api.{FetchRequest}
-import kafka.consumer.{Consumer, ConsumerConfig, KafkaStream}
-import kafka.javaapi.consumer.SimpleConsumer
-import kafka.javaapi.message.ByteBufferMessageSet
+import kafka.consumer._
+import kafka.cluster.Partition
import kafka.message.{Message, MessageSet, MessageAndMetadata}
-import kafka.utils.Utils
+import kafka.serializer.StringDecoder
+import kafka.utils.{Pool, Utils, ZKGroupTopicDirs}
+import kafka.utils.ZkUtils._
+import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import spark._
import spark.RDD
import spark.storage.StorageLevel
+case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
+case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long])
+case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
+ savedOffsets: HashMap[Long, Map[KafkaPartitionKey, Long]]) extends DStreamCheckpointData(kafkaRdds)
+
/**
- * An input stream that pulls messages form a Kafka Broker.
+ * Input stream that pulls messages form a Kafka Broker.
*/
class KafkaInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
host: String,
port: Int,
groupId: String,
- storageLevel: StorageLevel,
- timeout: Int = 10000,
- bufferSize: Int = 1024000
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long],
+ storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_ ) with Logging {
+ var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]()
+
+ override protected[streaming] def addMetadata(metadata: Any) {
+ metadata match {
+ case x : KafkaInputDStreamMetadata =>
+ savedOffsets(x.timestamp) = x.data
+ logInfo("Saved Offsets: " + savedOffsets)
+ case _ => logInfo("Received unknown metadata: " + metadata.toString)
+ }
+ }
+
+ override protected[streaming] def updateCheckpointData(currentTime: Time) {
+ super.updateCheckpointData(currentTime)
+ logInfo("Updating KafkaDStream checkpoint data: " + savedOffsets.toString)
+ checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, savedOffsets)
+ }
+
+ override protected[streaming] def restoreCheckpointData() {
+ super.restoreCheckpointData()
+ logInfo("Restoring KafkaDStream checkpoint data.")
+ checkpointData match {
+ case x : KafkaDStreamCheckpointData =>
+ savedOffsets = x.savedOffsets
+ logInfo("Restored KafkaDStream offsets: " + savedOffsets.toString)
+ }
+ }
+
def createReceiver(): NetworkReceiver[T] = {
- new KafkaReceiver(id, host, port, storageLevel, groupId, timeout).asInstanceOf[NetworkReceiver[T]]
+ new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel)
+ .asInstanceOf[NetworkReceiver[T]]
}
}
-class KafkaReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel, groupId: String, timeout: Int)
- extends NetworkReceiver[Any](streamId) {
+class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String,
+ topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
+ storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) {
+
+ // Timeout for establishing a connection to Zookeper in ms.
+ val ZK_TIMEOUT = 10000
- //var executorPool : = null
- var blockPushingThread : Thread = null
+ // Handles pushing data into the BlockManager
+ lazy protected val dataHandler = new KafkaDataHandler(this, storageLevel)
+ // Keeps track of the current offsets. Maps from (topic, partitionID) -> Offset
+ lazy val offsets = HashMap[KafkaPartitionKey, Long]()
+ // Connection to Kafka
+ var consumerConnector : ZookeeperConsumerConnector = null
def onStop() {
- blockPushingThread.interrupt()
+ dataHandler.stop()
}
def onStart() {
- val executorPool = Executors.newFixedThreadPool(2)
+ // Starting the DataHandler that buffers blocks and pushes them into them BlockManager
+ dataHandler.start()
- logInfo("Starting Kafka Consumer with groupId " + groupId)
+ // In case we are using multiple Threads to handle Kafka Messages
+ val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
val zooKeeperEndPoint = host + ":" + port
+ logInfo("Starting Kafka Consumer Stream in group " + groupId)
+ logInfo("Initial offsets: " + initialOffsets.toString)
logInfo("Connecting to " + zooKeeperEndPoint)
-
- // Specify some consumer properties
+ // Specify some Consumer properties
val props = new Properties()
props.put("zk.connect", zooKeeperEndPoint)
- props.put("zk.connectiontimeout.ms", timeout.toString)
+ props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
props.put("groupid", groupId)
// Create the connection to the cluster
val consumerConfig = new ConsumerConfig(props)
- val consumerConnector = Consumer.create(consumerConfig)
- logInfo("Connected to " + zooKeeperEndPoint)
- logInfo("")
- logInfo("")
+ consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
- // Specify which topics we are listening to
- val topicCountMap = Map("test" -> 2)
- val topicMessageStreams = consumerConnector.createMessageStreams(topicCountMap)
- val streams = topicMessageStreams.get("test")
+ // Reset the Kafka offsets in case we are recovering from a failure
+ resetOffsets(initialOffsets)
- // Queue that holds the blocks
- val queue = new ArrayBlockingQueue[ByteBuffer](2)
+ logInfo("Connected to " + zooKeeperEndPoint)
- streams.getOrElse(Nil).foreach { stream =>
- executorPool.submit(new MessageHandler(stream, queue))
+ // Create Threads for each Topic/Message Stream we are listening
+ val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
+
+ topicMessageStreams.values.foreach { streams =>
+ streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) }
}
- blockPushingThread = new DaemonThread {
- override def run() {
- logInfo("Starting BlockPushingThread.")
- var nextBlockNumber = 0
- while (true) {
- val buffer = queue.take()
- val blockId = "input-" + streamId + "-" + nextBlockNumber
- nextBlockNumber += 1
- pushBlock(blockId, buffer, storageLevel)
- }
- }
+ }
+
+ // Overwrites the offets in Zookeper.
+ private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) {
+ offsets.foreach { case(key, offset) =>
+ val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
+ val partitionName = key.brokerId + "-" + key.partId
+ updatePersistentPath(consumerConnector.zkClient,
+ topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
}
- blockPushingThread.start()
-
- // while (true) {
- // // Create a fetch request for topic “test”, partition 0, current offset, and fetch size of 1MB
- // val fetchRequest = new FetchRequest("test", 0, offset, 1000000)
-
- // // get the message set from the consumer and print them out
- // val messages = consumer.fetch(fetchRequest)
- // for(msg <- messages.iterator) {
- // logInfo("consumed: " + Utils.toString(msg.message.payload, "UTF-8"))
- // // advance the offset after consuming each message
- // offset = msg.offset
- // queue.put(msg.message.payload)
- // }
- // }
}
- class MessageHandler(stream: KafkaStream[Message], queue: ArrayBlockingQueue[ByteBuffer]) extends Runnable {
+ // Responsible for handling Kafka Messages
+ class MessageHandler(stream: KafkaStream[String]) extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
- while(true) {
- stream.foreach { msgAndMetadata =>
- logInfo("Consumed: " + Utils.toString(msgAndMetadata.message.payload, "UTF-8"))
- queue.put(msgAndMetadata.message.payload)
- }
- }
+ stream.takeWhile { msgAndMetadata =>
+ dataHandler += msgAndMetadata.message
+
+ // Updating the offet. The key is (topic, partitionID).
+ val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic,
+ groupId, msgAndMetadata.topicInfo.partition.partId)
+ val offset = msgAndMetadata.topicInfo.getConsumeOffset
+ offsets.put(key, offset)
+ logInfo((key, offset).toString)
+
+ // Keep on handling messages
+ true
+ }
}
}
+ class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel)
+ extends DataHandler[Any](receiver, storageLevel) {
+
+ override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = {
+ new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap))
+ }
+
+ }
}