Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala                  |  38
-rw-r--r--  streaming/src/main/scala/spark/streaming/DataHandler.scala              |  83
-rw-r--r--  streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala      |  21
-rw-r--r--  streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala      |   9
-rw-r--r--  streaming/src/main/scala/spark/streaming/RawInputDStream.scala          |   2
-rw-r--r--  streaming/src/main/scala/spark/streaming/SocketInputDStream.scala       |  70
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala         |  25
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala  |  68
-rw-r--r--  streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala  | 193
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala          |  12
10 files changed, 419 insertions, 102 deletions
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index d2e9de110e..85106b3ad8 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -37,6 +37,9 @@ import org.apache.hadoop.conf.Configuration
* - A time interval at which the DStream generates an RDD
* - A function that is used to generate an RDD after each time interval
*/
+
+case class DStreamCheckpointData(rdds: HashMap[Time, Any])
+
abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
extends Serializable with Logging {
@@ -79,7 +82,7 @@ extends Serializable with Logging {
// Checkpoint details
protected[streaming] val mustCheckpoint = false
protected[streaming] var checkpointInterval: Time = null
- protected[streaming] val checkpointData = new HashMap[Time, Any]()
+ protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]())
// Reference to whole DStream graph
protected[streaming] var graph: DStreamGraph = null
@@ -314,6 +317,15 @@ extends Serializable with Logging {
dependencies.foreach(_.forgetOldRDDs(time))
}
+ /* Adds metadata to the stream while it is running.
+ * This method should be overridden by subclasses of InputDStream.
+ */
+ protected[streaming] def addMetadata(metadata: Any) {
+ if (metadata != null) {
+ logInfo("Dropping Metadata: " + metadata.toString)
+ }
+ }
+
/**
* Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of
* this stream. This is an internal method that should not be called directly. This is
@@ -322,31 +334,29 @@ extends Serializable with Logging {
* this method to save custom checkpoint data.
*/
protected[streaming] def updateCheckpointData(currentTime: Time) {
+
logInfo("Updating checkpoint data for time " + currentTime)
// Get the checkpointed RDDs from the generated RDDs
- val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null)
+ val newRdds = generatedRDDs.filter(_._2.getCheckpointData() != null)
.map(x => (x._1, x._2.getCheckpointData()))
// Make a copy of the existing checkpoint data
- val oldCheckpointData = checkpointData.clone()
-
+ val oldRdds = checkpointData.rdds.clone()
// If the new checkpoint has checkpoints then replace existing with the new one
- if (newCheckpointData.size > 0) {
- checkpointData.clear()
- checkpointData ++= newCheckpointData
+ if (newRdds.size > 0) {
+ checkpointData.rdds.clear()
+ checkpointData.rdds ++= newRdds
}
-
// Make dependencies update their checkpoint data
dependencies.foreach(_.updateCheckpointData(currentTime))
// TODO: remove this, this is just for debugging
- newCheckpointData.foreach {
+ newRdds.foreach {
case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
}
- // If old checkpoint files have been removed from checkpoint data, then remove the files
- if (newCheckpointData.size > 0) {
- (oldCheckpointData -- newCheckpointData.keySet).foreach {
+ if (newRdds.size > 0) {
+ (oldRdds -- newRdds.keySet).foreach {
case (time, data) => {
val path = new Path(data.toString)
val fs = path.getFileSystem(new Configuration())
@@ -367,8 +377,8 @@ extends Serializable with Logging {
*/
protected[streaming] def restoreCheckpointData() {
// Create RDDs from the checkpoint data
- logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs")
- checkpointData.foreach {
+ logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs")
+ checkpointData.rdds.foreach {
case(time, data) => {
logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'")
val rdd = ssc.sc.objectFile[T](data.toString)
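As an aside on the addMetadata hook added to DStream above: it is meant to be overridden by input streams that want to consume per-block metadata reported by their receivers (KafkaInputDStream below sketches this, though its implementation is commented out). A minimal, hypothetical sketch of such an override; MyStreamMetadata and MyInputDStream are illustrative names, not part of this patch:

    package spark.streaming

    // Illustrative only: MyStreamMetadata and MyInputDStream are not part of this patch.
    case class MyStreamMetadata(timestamp: Long, info: String)

    class MyInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext)
      extends NetworkInputDStream[T](ssc_) {

      def createReceiver(): NetworkReceiver[T] =
        throw new UnsupportedOperationException("receiver omitted in this sketch")

      // Called by the NetworkInputTracker when a receiver reports a block with metadata attached
      override protected[streaming] def addMetadata(metadata: Any) {
        metadata match {
          case m: MyStreamMetadata => logInfo("Got metadata for time " + m.timestamp + ": " + m.info)
          case _                   => super.addMetadata(metadata) // default: log and drop
        }
      }
    }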
diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala
new file mode 100644
index 0000000000..05f307a8d1
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DataHandler.scala
@@ -0,0 +1,83 @@
+package spark.streaming
+
+import java.util.concurrent.ArrayBlockingQueue
+import scala.collection.mutable.ArrayBuffer
+import spark.Logging
+import spark.streaming.util.{RecurringTimer, SystemClock}
+import spark.storage.StorageLevel
+
+
+/**
+ * This is a helper object that manages the data received from the socket. It divides
+ * the data received into small batches of a few hundred milliseconds, pushes them as
+ * blocks into the block manager and reports the block IDs to the network input
+ * tracker. It starts two threads, one to periodically start a new batch and prepare
+ * the previous batch as a block, the other to push the blocks into the block
+ * manager.
+ */
+ class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel)
+ extends Serializable with Logging {
+
+ case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
+
+ val clock = new SystemClock()
+ val blockInterval = 200L
+ val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+ val blockStorageLevel = storageLevel
+ val blocksForPushing = new ArrayBlockingQueue[Block](1000)
+ val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+ var currentBuffer = new ArrayBuffer[T]
+
+ def createBlock(blockId: String, iterator: Iterator[T]) : Block = {
+ new Block(blockId, iterator)
+ }
+
+ def start() {
+ blockIntervalTimer.start()
+ blockPushingThread.start()
+ logInfo("Data handler started")
+ }
+
+ def stop() {
+ blockIntervalTimer.stop()
+ blockPushingThread.interrupt()
+ logInfo("Data handler stopped")
+ }
+
+ def += (obj: T) {
+ currentBuffer += obj
+ }
+
+ def updateCurrentBuffer(time: Long) {
+ try {
+ val newBlockBuffer = currentBuffer
+ currentBuffer = new ArrayBuffer[T]
+ if (newBlockBuffer.size > 0) {
+ val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval)
+ val newBlock = createBlock(blockId, newBlockBuffer.toIterator)
+ blocksForPushing.add(newBlock)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block interval timer thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+
+ def keepPushingBlocks() {
+ logInfo("Block pushing thread started")
+ try {
+ while(true) {
+ val block = blocksForPushing.take()
+ receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block pushing thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+ }
\ No newline at end of file
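For reference, the intended usage of DataHandler, as seen in SocketReceiver and KafkaReceiver elsewhere in this diff, is for a NetworkReceiver to instantiate it, feed it records with +=, and let it batch them into blocks. A rough sketch under those assumptions; SketchReceiver and its record source are illustrative only:

    package spark.streaming

    import spark.storage.StorageLevel

    // Illustrative only: SketchReceiver is not part of this patch.
    class SketchReceiver(streamId: Int, storageLevel: StorageLevel)
      extends NetworkReceiver[String](streamId) {

      // Batches records into blocks and pushes them to the BlockManager
      lazy protected val dataHandler = new DataHandler(this, storageLevel)

      protected def onStart() {
        dataHandler.start()
        // A real receiver would feed records from a socket, queue, etc.
        Seq("hello", "world").foreach(record => dataHandler += record)
      }

      protected def onStop() {
        dataHandler.stop()
      }
    }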
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
index f3f4c3ab13..d3f37b8b0e 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
@@ -4,9 +4,11 @@ import scala.collection.mutable.ArrayBuffer
import spark.{Logging, SparkEnv, RDD}
import spark.rdd.BlockRDD
+import spark.streaming.util.{RecurringTimer, SystemClock}
import spark.storage.StorageLevel
import java.nio.ByteBuffer
+import java.util.concurrent.ArrayBlockingQueue
import akka.actor.{Props, Actor}
import akka.pattern.ask
@@ -41,10 +43,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
sealed trait NetworkReceiverMessage
case class StopReceiver(msg: String) extends NetworkReceiverMessage
-case class ReportBlock(blockId: String) extends NetworkReceiverMessage
+case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
case class ReportError(msg: String) extends NetworkReceiverMessage
-abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging {
+abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging {
initLogging()
@@ -106,21 +108,23 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
actor ! ReportError(e.toString)
}
+
/**
* This method pushes a block (as iterator of values) into the block manager.
*/
- protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) {
+ def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) {
val buffer = new ArrayBuffer[T] ++ iterator
env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level)
- actor ! ReportBlock(blockId)
+
+ actor ! ReportBlock(blockId, metadata)
}
/**
* This method pushes a block (as bytes) into the block manager.
*/
- protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
- actor ! ReportBlock(blockId)
+ actor ! ReportBlock(blockId, metadata)
}
/** A helper actor that communicates with the NetworkInputTracker */
@@ -138,8 +142,8 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
}
override def receive() = {
- case ReportBlock(blockId) =>
- tracker ! AddBlocks(streamId, Array(blockId))
+ case ReportBlock(blockId, metadata) =>
+ tracker ! AddBlocks(streamId, Array(blockId), metadata)
case ReportError(msg) =>
tracker ! DeregisterReceiver(streamId, msg)
case StopReceiver(msg) =>
@@ -147,5 +151,6 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
tracker ! DeregisterReceiver(streamId, msg)
}
}
+
}
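The metadata argument threaded through pushBlock and ReportBlock above travels via the tracker's AddBlocks message to the owning stream's addMetadata hook. A receiver can attach per-block metadata by overriding DataHandler.createBlock, mirroring the commented-out KafkaDataHandler later in this diff; a hypothetical sketch (MetadataDataHandler is not part of the patch):

    package spark.streaming

    import spark.storage.StorageLevel

    // Illustrative only: attach per-block metadata by overriding createBlock. The metadata
    // is delivered to the owning DStream via ReportBlock -> AddBlocks -> addMetadata.
    class MetadataDataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel)
      extends DataHandler[T](receiver, storageLevel) {

      override def createBlock(blockId: String, iterator: Iterator[T]): Block = {
        // Block is the inner case class of DataHandler: (id, iterator, metadata)
        new Block(blockId, iterator, System.currentTimeMillis)
      }
    }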
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
index ae6692290e..73ba877085 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -13,7 +13,7 @@ import akka.dispatch._
trait NetworkInputTrackerMessage
case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
-case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage
+case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
@@ -22,7 +22,7 @@ class NetworkInputTracker(
@transient networkInputStreams: Array[NetworkInputDStream[_]])
extends Logging {
- val networkInputStreamIds = networkInputStreams.map(_.id).toArray
+ val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
val receivedBlockIds = new HashMap[Int, Queue[String]]
@@ -54,14 +54,14 @@ class NetworkInputTracker(
private class NetworkInputTrackerActor extends Actor {
def receive = {
case RegisterReceiver(streamId, receiverActor) => {
- if (!networkInputStreamIds.contains(streamId)) {
+ if (!networkInputStreamMap.contains(streamId)) {
throw new Exception("Register received for unexpected id " + streamId)
}
receiverInfo += ((streamId, receiverActor))
logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address)
sender ! true
}
- case AddBlocks(streamId, blockIds) => {
+ case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) {
receivedBlockIds += ((streamId, new Queue[String]))
@@ -71,6 +71,7 @@ class NetworkInputTracker(
tmp.synchronized {
tmp ++= blockIds
}
+ networkInputStreamMap(streamId).addMetadata(metadata)
}
case DeregisterReceiver(streamId, msg) => {
receiverInfo -= streamId
diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
index 03726bfba6..d5db8e787d 100644
--- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
@@ -48,7 +48,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S
val buffer = queue.take()
val blockId = "input-" + streamId + "-" + nextBlockNumber
nextBlockNumber += 1
- pushBlock(blockId, buffer, storageLevel)
+ pushBlock(blockId, buffer, null, storageLevel)
}
}
}
diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
index b566200273..ff99d50b76 100644
--- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
@@ -32,7 +32,7 @@ class SocketReceiver[T: ClassManifest](
storageLevel: StorageLevel
) extends NetworkReceiver[T](streamId) {
- lazy protected val dataHandler = new DataHandler(this)
+ lazy protected val dataHandler = new DataHandler(this, storageLevel)
protected def onStart() {
logInfo("Connecting to " + host + ":" + port)
@@ -50,74 +50,6 @@ class SocketReceiver[T: ClassManifest](
dataHandler.stop()
}
- /**
- * This is a helper object that manages the data received from the socket. It divides
- * the object received into small batches of 100s of milliseconds, pushes them as
- * blocks into the block manager and reports the block IDs to the network input
- * tracker. It starts two threads, one to periodically start a new batch and prepare
- * the previous batch of as a block, the other to push the blocks into the block
- * manager.
- */
- class DataHandler(receiver: NetworkReceiver[T]) extends Serializable {
- case class Block(id: String, iterator: Iterator[T])
-
- val clock = new SystemClock()
- val blockInterval = 200L
- val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
- val blockStorageLevel = storageLevel
- val blocksForPushing = new ArrayBlockingQueue[Block](1000)
- val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
-
- var currentBuffer = new ArrayBuffer[T]
-
- def start() {
- blockIntervalTimer.start()
- blockPushingThread.start()
- logInfo("Data handler started")
- }
-
- def stop() {
- blockIntervalTimer.stop()
- blockPushingThread.interrupt()
- logInfo("Data handler stopped")
- }
-
- def += (obj: T) {
- currentBuffer += obj
- }
-
- def updateCurrentBuffer(time: Long) {
- try {
- val newBlockBuffer = currentBuffer
- currentBuffer = new ArrayBuffer[T]
- if (newBlockBuffer.size > 0) {
- val blockId = "input-" + streamId + "- " + (time - blockInterval)
- val newBlock = new Block(blockId, newBlockBuffer.toIterator)
- blocksForPushing.add(newBlock)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block interval timer thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
-
- def keepPushingBlocks() {
- logInfo("Block pushing thread started")
- try {
- while(true) {
- val block = blocksForPushing.take()
- pushBlock(block.id, block.iterator, storageLevel)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block pushing thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
- }
}
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 9c19f6588d..8153dd4567 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -122,6 +122,31 @@ class StreamingContext private (
private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+ /**
+ * Create an input stream that pulls messages from a Kafka broker.
+ *
+ * @param hostname ZooKeeper hostname.
+ * @param port ZooKeeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from ZooKeeper.
+ * @param storageLevel RDD storage level. Defaults to memory-only.
+ */
+ def kafkaStream[T: ClassManifest](
+ hostname: String,
+ port: Int,
+ groupId: String,
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
+ ): DStream[T] = {
+ val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
def networkTextStream(
hostname: String,
port: Int,
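A minimal usage sketch of the new kafkaStream method (host, topic, and batch duration values are illustrative; KafkaWordCount below is the complete example):

    package spark.streaming.examples

    import spark.streaming.{Seconds, StreamingContext}

    // Illustrative only: not part of this patch.
    object KafkaStreamSketch {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local[2]", "KafkaStreamSketch")
        ssc.setBatchDuration(Seconds(2))
        // Consume topic "events" with two threads; offsets are read from ZooKeeper at localhost:2181
        val lines = ssc.kafkaStream[String]("localhost", 2181, "sketch-group", Map("events" -> 2))
        lines.print()
        ssc.start()
      }
    }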
diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
new file mode 100644
index 0000000000..12e3f49fe9
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -0,0 +1,68 @@
+package spark.streaming.examples
+
+import java.util.Properties
+import kafka.message.Message
+import kafka.producer.SyncProducerConfig
+import kafka.producer._
+import spark.streaming._
+import spark.streaming.StreamingContext._
+import spark.storage.StorageLevel
+import spark.streaming.util.RawTextHelper._
+
+object KafkaWordCount {
+ def main(args: Array[String]) {
+
+ if (args.length < 6) {
+ System.err.println("Usage: KafkaWordCount <master> <hostname> <port> <group> <topics> <numThreads>")
+ System.exit(1)
+ }
+
+ val Array(master, hostname, port, group, topics, numThreads) = args
+
+ val ssc = new StreamingContext(master, "KafkaWordCount")
+ ssc.checkpoint("checkpoint")
+ ssc.setBatchDuration(Seconds(2))
+
+ val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
+ val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
+ wordCounts.print()
+
+ ssc.start()
+ }
+}
+
+// Produces messages made up of random digit words (0-9).
+object KafkaWordCountProducer {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: KafkaWordCountProducer <hostname> <port> <topic> <messagesPerSec> <wordsPerMessage>")
+ System.exit(1)
+ }
+
+ val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args
+
+ // ZooKeeper connection properties
+ val props = new Properties()
+ props.put("zk.connect", hostname + ":" + port)
+ props.put("serializer.class", "kafka.serializer.StringEncoder")
+
+ val config = new ProducerConfig(props)
+ val producer = new Producer[String, String](config)
+
+ // Send some messages
+ while(true) {
+ val messages = (1 to messagesPerSec.toInt).map { messageNum =>
+ (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
+ }.toArray
+ println(messages.mkString(","))
+ val data = new ProducerData[String, String](topic, messages)
+ producer.send(data)
+ Thread.sleep(100)
+ }
+ }
+
+}
+
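For a quick local smoke test, both objects can be driven from a separate, hypothetical entry point; the argument values below are examples only and mirror the usage strings above:

    package spark.streaming.examples

    // Illustrative only: not part of this patch.
    object KafkaWordCountLocalDriver {
      def main(args: Array[String]) {
        // Producer args: <hostname> <port> <topic> <messagesPerSec> <wordsPerMessage>
        new Thread() {
          override def run() {
            KafkaWordCountProducer.main(Array("localhost", "2181", "test", "10", "5"))
          }
        }.start()

        // Consumer args: <master> <hostname> <port> <group> <topics> <numThreads>
        KafkaWordCount.main(Array("local[2]", "localhost", "2181", "my-group", "test", "1"))
      }
    }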
diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
new file mode 100644
index 0000000000..7c642d4802
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
@@ -0,0 +1,193 @@
+package spark.streaming
+
+import java.util.Properties
+import java.util.concurrent.Executors
+import kafka.consumer._
+import kafka.message.{Message, MessageSet, MessageAndMetadata}
+import kafka.serializer.StringDecoder
+import kafka.utils.{Utils, ZKGroupTopicDirs}
+import kafka.utils.ZkUtils._
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import spark._
+import spark.RDD
+import spark.storage.StorageLevel
+
+// Key for a specific Kafka Partition: (broker, topic, group, part)
+case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
+// NOT USED - Originally intended for fault-tolerance
+// Metadata for a Kafka stream that is sent to the Master
+case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long])
+// NOT USED - Originally intended for fault-tolerance
+// Checkpoint data specific to a KafkaInputDstream
+case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
+ savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds)
+
+/**
+ * Input stream that pulls messages from a Kafka broker.
+ *
+ * @param host ZooKeeper hostname.
+ * @param port ZooKeeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from ZooKeeper.
+ * @param storageLevel RDD storage level.
+ */
+class KafkaInputDStream[T: ClassManifest](
+ @transient ssc_ : StreamingContext,
+ host: String,
+ port: Int,
+ groupId: String,
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long],
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[T](ssc_ ) with Logging {
+
+ // Metadata that keeps track of which messages have already been consumed.
+ var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]()
+
+ /* NOT USED - Originally intended for fault-tolerance
+
+ // In case of a failure, the offsets for a particular timestamp will be restored.
+ @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null
+
+
+ override protected[streaming] def addMetadata(metadata: Any) {
+ metadata match {
+ case x : KafkaInputDStreamMetadata =>
+ savedOffsets(x.timestamp) = x.data
+ // TODO: Remove logging
+ logInfo("New saved Offsets: " + savedOffsets)
+ case _ => logInfo("Received unknown metadata: " + metadata.toString)
+ }
+ }
+
+ override protected[streaming] def updateCheckpointData(currentTime: Time) {
+ super.updateCheckpointData(currentTime)
+ if(savedOffsets.size > 0) {
+ // Find the offsets that were stored before the checkpoint was initiated
+ val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last
+ val latestOffsets = savedOffsets(key)
+ logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString)
+ checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets)
+ // TODO: This may throw out offsets that are created after the checkpoint,
+ // but it's unlikely we'll need them.
+ savedOffsets.clear()
+ }
+ }
+
+ override protected[streaming] def restoreCheckpointData() {
+ super.restoreCheckpointData()
+ logInfo("Restoring KafkaDStream checkpoint data.")
+ checkpointData match {
+ case x : KafkaDStreamCheckpointData =>
+ restoredOffsets = x.savedOffsets
+ logInfo("Restored KafkaDStream offsets: " + savedOffsets)
+ }
+ } */
+
+ def createReceiver(): NetworkReceiver[T] = {
+ new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel)
+ .asInstanceOf[NetworkReceiver[T]]
+ }
+}
+
+class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String,
+ topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
+ storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) {
+
+ // Timeout for establishing a connection to ZooKeeper, in milliseconds.
+ val ZK_TIMEOUT = 10000
+
+ // Handles pushing data into the BlockManager
+ lazy protected val dataHandler = new DataHandler(this, storageLevel)
+ // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset
+ lazy val offsets = HashMap[KafkaPartitionKey, Long]()
+ // Connection to Kafka
+ var consumerConnector : ZookeeperConsumerConnector = null
+
+ def onStop() {
+ dataHandler.stop()
+ }
+
+ def onStart() {
+
+ // Start the DataHandler that buffers blocks and pushes them into the BlockManager
+ dataHandler.start()
+
+ // Thread pool for handling Kafka messages, one thread per requested partition
+ val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
+
+ val zooKeeperEndPoint = host + ":" + port
+ logInfo("Starting Kafka Consumer Stream with group: " + groupId)
+ logInfo("Initial offsets: " + initialOffsets.toString)
+
+ // ZooKeeper connection properties
+ val props = new Properties()
+ props.put("zk.connect", zooKeeperEndPoint)
+ props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
+ props.put("groupid", groupId)
+
+ // Create the connection to the cluster
+ logInfo("Connecting to ZooKeeper: " + zooKeeperEndPoint)
+ val consumerConfig = new ConsumerConfig(props)
+ consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
+ logInfo("Connected to " + zooKeeperEndPoint)
+
+ // Reset the Kafka offsets in case we are recovering from a failure
+ resetOffsets(initialOffsets)
+
+ // Create threads for each topic/message stream we are listening to
+ val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
+
+ // Start the messages handler for each partition
+ topicMessageStreams.values.foreach { streams =>
+ streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) }
+ }
+
+ }
+
+ // Overwrites the offsets in ZooKeeper.
+ private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) {
+ offsets.foreach { case(key, offset) =>
+ val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
+ val partitionName = key.brokerId + "-" + key.partId
+ updatePersistentPath(consumerConnector.zkClient,
+ topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
+ }
+ }
+
+ // Handles Kafka Messages
+ private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
+ def run() {
+ logInfo("Starting MessageHandler.")
+ stream.takeWhile { msgAndMetadata =>
+ dataHandler += msgAndMetadata.message
+
+ // Update the offset. The key is (broker, topic, group, partition).
+ val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic,
+ groupId, msgAndMetadata.topicInfo.partition.partId)
+ val offset = msgAndMetadata.topicInfo.getConsumeOffset
+ offsets.put(key, offset)
+ // logInfo("Handled message: " + (key, offset).toString)
+
+ // Keep on handling messages
+ true
+ }
+ }
+ }
+
+ // NOT USED - Originally intended for fault-tolerance
+ // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel)
+ // extends DataHandler[Any](receiver, storageLevel) {
+
+ // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = {
+ // // Creates a new Block with Kafka-specific Metadata
+ // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap))
+ // }
+
+ // }
+
+}
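To resume consumption from known offsets instead of whatever ZooKeeper currently holds, initialOffsets can be passed explicitly; the receiver writes them back to ZooKeeper via resetOffsets before it starts consuming. A sketch with illustrative broker, partition, and offset values:

    package spark.streaming.examples

    import spark.streaming.{KafkaPartitionKey, Seconds, StreamingContext}

    // Illustrative only: broker 0, partition 0 of topic "events", resuming at offset 12345.
    object KafkaOffsetSketch {
      def main(args: Array[String]) {
        val initialOffsets = Map(KafkaPartitionKey(0, "events", "my-group", 0) -> 12345L)
        val ssc = new StreamingContext("local[2]", "KafkaOffsetSketch")
        ssc.setBatchDuration(Seconds(2))
        val stream = ssc.kafkaStream[String](
          "localhost", 2181, "my-group", Map("events" -> 1), initialOffsets)
        stream.print()
        ssc.start()
      }
    }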
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index b3afedf39f..0d82b2f1ea 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -63,9 +63,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// then check whether some RDD has been checkpointed or not
ssc.start()
runStreamsWithRealDelay(ssc, firstNumBatches)
- logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]")
- assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure")
- stateStream.checkpointData.foreach {
+ logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]")
+ assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure")
+ stateStream.checkpointData.rdds.foreach {
case (time, data) => {
val file = new File(data.toString)
assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist")
@@ -74,7 +74,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// Run till a further time such that previous checkpoint files in the stream would be deleted
// and check whether the earlier checkpoint files are deleted
- val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString))
+ val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString))
runStreamsWithRealDelay(ssc, secondNumBatches)
checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
ssc.stop()
@@ -91,8 +91,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// is present in the checkpoint data or not
ssc.start()
runStreamsWithRealDelay(ssc, 1)
- assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure")
- stateStream.checkpointData.foreach {
+ assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure")
+ stateStream.checkpointData.rdds.foreach {
case (time, data) => {
val file = new File(data.toString)
assert(file.exists(),