aboutsummaryrefslogtreecommitdiff
path: root/streaming/src
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2012-12-20 14:24:19 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2012-12-20 14:24:19 -0800
commit8512dd3225a3ce9c38608a319d7f85fdf75798b4 (patch)
tree088fa6bcdca8b5970131d69ab2a2bc5ccd679cf9 /streaming/src
parentfe777eb77dee3c5bc5a7a332098d27f517ad3fe4 (diff)
parent2a87d816a24c62215d682e3a7af65489c0d6e708 (diff)
downloadspark-8512dd3225a3ce9c38608a319d7f85fdf75798b4.tar.gz
spark-8512dd3225a3ce9c38608a319d7f85fdf75798b4.tar.bz2
spark-8512dd3225a3ce9c38608a319d7f85fdf75798b4.zip
Merge branch 'dev' of github.com:radlab/spark into dev-checkpoint
Conflicts: core/src/main/scala/spark/ParallelCollection.scala core/src/test/scala/spark/CheckpointSuite.scala streaming/src/main/scala/spark/streaming/DStream.scala
Diffstat (limited to 'streaming/src')
-rw-r--r--streaming/src/main/scala/spark/streaming/DStream.scala47
-rw-r--r--streaming/src/main/scala/spark/streaming/DataHandler.scala83
-rw-r--r--streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala130
-rw-r--r--streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala23
-rw-r--r--streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala22
-rw-r--r--streaming/src/main/scala/spark/streaming/RawInputDStream.scala4
-rw-r--r--streaming/src/main/scala/spark/streaming/SocketInputDStream.scala72
-rw-r--r--streaming/src/main/scala/spark/streaming/StreamingContext.scala36
-rw-r--r--streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala43
-rw-r--r--streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala69
-rw-r--r--streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala193
-rw-r--r--streaming/src/test/scala/spark/streaming/CheckpointSuite.scala12
-rw-r--r--streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala59
13 files changed, 683 insertions, 110 deletions
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 69fefa21a0..d5048aeed7 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -37,6 +37,9 @@ import org.apache.hadoop.conf.Configuration
* - A time interval at which the DStream generates an RDD
* - A function that is used to generate an RDD after each time interval
*/
+
+case class DStreamCheckpointData(rdds: HashMap[Time, Any])
+
abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
extends Serializable with Logging {
@@ -79,7 +82,7 @@ extends Serializable with Logging {
// Checkpoint details
protected[streaming] val mustCheckpoint = false
protected[streaming] var checkpointInterval: Time = null
- protected[streaming] val checkpointData = new HashMap[Time, Any]()
+ protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]())
// Reference to whole DStream graph
protected[streaming] var graph: DStreamGraph = null
@@ -314,6 +317,15 @@ extends Serializable with Logging {
dependencies.foreach(_.forgetOldRDDs(time))
}
+ /* Adds metadata to the Stream while it is running.
+ * This methd should be overwritten by sublcasses of InputDStream.
+ */
+ protected[streaming] def addMetadata(metadata: Any) {
+ if (metadata != null) {
+ logInfo("Dropping Metadata: " + metadata.toString)
+ }
+ }
+
/**
* Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of
* this stream. This is an internal method that should not be called directly. This is
@@ -325,29 +337,28 @@ extends Serializable with Logging {
logInfo("Updating checkpoint data for time " + currentTime)
// Get the checkpointed RDDs from the generated RDDs
+ val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+ .map(x => (x._1, x._2.getCheckpointFile.get))
- val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
- .map(x => (x._1, x._2.getCheckpointFile.get))
- // Make a copy of the existing checkpoint data
- val oldCheckpointData = checkpointData.clone()
+ // Make a copy of the existing checkpoint data (checkpointed RDDs)
+ val oldRdds = checkpointData.rdds.clone()
- // If the new checkpoint has checkpoints then replace existing with the new one
- if (newCheckpointData.size > 0) {
- checkpointData.clear()
- checkpointData ++= newCheckpointData
+ // If the new checkpoint data has checkpoints then replace existing with the new one
+ if (newRdds.size > 0) {
+ checkpointData.rdds.clear()
+ checkpointData.rdds ++= newRdds
}
- // Make dependencies update their checkpoint data
+ // Make parent DStreams update their checkpoint data
dependencies.foreach(_.updateCheckpointData(currentTime))
// TODO: remove this, this is just for debugging
- newCheckpointData.foreach {
+ newRdds.foreach {
case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
}
- // If old checkpoint files have been removed from checkpoint data, then remove the files
- if (newCheckpointData.size > 0) {
- (oldCheckpointData -- newCheckpointData.keySet).foreach {
+ if (newRdds.size > 0) {
+ (oldRdds -- newRdds.keySet).foreach {
case (time, data) => {
val path = new Path(data.toString)
val fs = path.getFileSystem(new Configuration())
@@ -356,8 +367,8 @@ extends Serializable with Logging {
}
}
}
- logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.size + " checkpoints, "
- + "[" + checkpointData.mkString(",") + "]")
+ logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.rdds.size + " checkpoints, "
+ + "[" + checkpointData.rdds.mkString(",") + "]")
}
/**
@@ -368,8 +379,8 @@ extends Serializable with Logging {
*/
protected[streaming] def restoreCheckpointData() {
// Create RDDs from the checkpoint data
- logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs")
- checkpointData.foreach {
+ logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs")
+ checkpointData.rdds.foreach {
case(time, data) => {
logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'")
val rdd = ssc.sc.checkpointFile[T](data.toString)
diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala
new file mode 100644
index 0000000000..05f307a8d1
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DataHandler.scala
@@ -0,0 +1,83 @@
+package spark.streaming
+
+import java.util.concurrent.ArrayBlockingQueue
+import scala.collection.mutable.ArrayBuffer
+import spark.Logging
+import spark.streaming.util.{RecurringTimer, SystemClock}
+import spark.storage.StorageLevel
+
+
+/**
+ * This is a helper object that manages the data received from the socket. It divides
+ * the object received into small batches of 100s of milliseconds, pushes them as
+ * blocks into the block manager and reports the block IDs to the network input
+ * tracker. It starts two threads, one to periodically start a new batch and prepare
+ * the previous batch of as a block, the other to push the blocks into the block
+ * manager.
+ */
+ class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel)
+ extends Serializable with Logging {
+
+ case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
+
+ val clock = new SystemClock()
+ val blockInterval = 200L
+ val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+ val blockStorageLevel = storageLevel
+ val blocksForPushing = new ArrayBlockingQueue[Block](1000)
+ val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+ var currentBuffer = new ArrayBuffer[T]
+
+ def createBlock(blockId: String, iterator: Iterator[T]) : Block = {
+ new Block(blockId, iterator)
+ }
+
+ def start() {
+ blockIntervalTimer.start()
+ blockPushingThread.start()
+ logInfo("Data handler started")
+ }
+
+ def stop() {
+ blockIntervalTimer.stop()
+ blockPushingThread.interrupt()
+ logInfo("Data handler stopped")
+ }
+
+ def += (obj: T) {
+ currentBuffer += obj
+ }
+
+ def updateCurrentBuffer(time: Long) {
+ try {
+ val newBlockBuffer = currentBuffer
+ currentBuffer = new ArrayBuffer[T]
+ if (newBlockBuffer.size > 0) {
+ val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval)
+ val newBlock = createBlock(blockId, newBlockBuffer.toIterator)
+ blocksForPushing.add(newBlock)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block interval timer thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+
+ def keepPushingBlocks() {
+ logInfo("Block pushing thread started")
+ try {
+ while(true) {
+ val block = blocksForPushing.take()
+ receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block pushing thread interrupted")
+ case e: Exception =>
+ receiver.stop()
+ }
+ }
+ } \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala
new file mode 100644
index 0000000000..2959ce4540
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala
@@ -0,0 +1,130 @@
+package spark.streaming
+
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import spark.storage.StorageLevel
+import org.apache.flume.source.avro.AvroSourceProtocol
+import org.apache.flume.source.avro.AvroFlumeEvent
+import org.apache.flume.source.avro.Status
+import org.apache.avro.ipc.specific.SpecificResponder
+import org.apache.avro.ipc.NettyServer
+import java.net.InetSocketAddress
+import collection.JavaConversions._
+import spark.Utils
+import java.nio.ByteBuffer
+
+class FlumeInputDStream[T: ClassManifest](
+ @transient ssc_ : StreamingContext,
+ host: String,
+ port: Int,
+ storageLevel: StorageLevel
+) extends NetworkInputDStream[SparkFlumeEvent](ssc_) {
+
+ override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = {
+ new FlumeReceiver(id, host, port, storageLevel)
+ }
+}
+
+/**
+ * A wrapper class for AvroFlumeEvent's with a custom serialization format.
+ *
+ * This is necessary because AvroFlumeEvent uses inner data structures
+ * which are not serializable.
+ */
+class SparkFlumeEvent() extends Externalizable {
+ var event : AvroFlumeEvent = new AvroFlumeEvent()
+
+ /* De-serialize from bytes. */
+ def readExternal(in: ObjectInput) {
+ val bodyLength = in.readInt()
+ val bodyBuff = new Array[Byte](bodyLength)
+ in.read(bodyBuff)
+
+ val numHeaders = in.readInt()
+ val headers = new java.util.HashMap[CharSequence, CharSequence]
+
+ for (i <- 0 until numHeaders) {
+ val keyLength = in.readInt()
+ val keyBuff = new Array[Byte](keyLength)
+ in.read(keyBuff)
+ val key : String = Utils.deserialize(keyBuff)
+
+ val valLength = in.readInt()
+ val valBuff = new Array[Byte](valLength)
+ in.read(valBuff)
+ val value : String = Utils.deserialize(valBuff)
+
+ headers.put(key, value)
+ }
+
+ event.setBody(ByteBuffer.wrap(bodyBuff))
+ event.setHeaders(headers)
+ }
+
+ /* Serialize to bytes. */
+ def writeExternal(out: ObjectOutput) {
+ val body = event.getBody.array()
+ out.writeInt(body.length)
+ out.write(body)
+
+ val numHeaders = event.getHeaders.size()
+ out.writeInt(numHeaders)
+ for ((k, v) <- event.getHeaders) {
+ val keyBuff = Utils.serialize(k.toString)
+ out.writeInt(keyBuff.length)
+ out.write(keyBuff)
+ val valBuff = Utils.serialize(v.toString)
+ out.writeInt(valBuff.length)
+ out.write(valBuff)
+ }
+ }
+}
+
+object SparkFlumeEvent {
+ def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = {
+ val event = new SparkFlumeEvent
+ event.event = in
+ event
+ }
+}
+
+/** A simple server that implements Flume's Avro protocol. */
+class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol {
+ override def append(event : AvroFlumeEvent) : Status = {
+ receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)
+ Status.OK
+ }
+
+ override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = {
+ events.foreach (event =>
+ receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event))
+ Status.OK
+ }
+}
+
+/** A NetworkReceiver which listens for events using the
+ * Flume Avro interface.*/
+class FlumeReceiver(
+ streamId: Int,
+ host: String,
+ port: Int,
+ storageLevel: StorageLevel
+ ) extends NetworkReceiver[SparkFlumeEvent](streamId) {
+
+ lazy val dataHandler = new DataHandler(this, storageLevel)
+
+ protected override def onStart() {
+ val responder = new SpecificResponder(
+ classOf[AvroSourceProtocol], new FlumeEventServer(this));
+ val server = new NettyServer(responder, new InetSocketAddress(host, port));
+ dataHandler.start()
+ server.start()
+ logInfo("Flume receiver started")
+ }
+
+ protected override def onStop() {
+ dataHandler.stop()
+ logInfo("Flume receiver stopped")
+ }
+
+ override def getLocationPreference = Some(host)
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
index f3f4c3ab13..4e4e9fc942 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
@@ -4,6 +4,7 @@ import scala.collection.mutable.ArrayBuffer
import spark.{Logging, SparkEnv, RDD}
import spark.rdd.BlockRDD
+import spark.streaming.util.{RecurringTimer, SystemClock}
import spark.storage.StorageLevel
import java.nio.ByteBuffer
@@ -41,10 +42,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
sealed trait NetworkReceiverMessage
case class StopReceiver(msg: String) extends NetworkReceiverMessage
-case class ReportBlock(blockId: String) extends NetworkReceiverMessage
+case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
case class ReportError(msg: String) extends NetworkReceiverMessage
-abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging {
+abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging {
initLogging()
@@ -61,6 +62,9 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
/** This method will be called to stop receiving data. */
protected def onStop()
+ /** This method conveys a placement preference (hostname) for this receiver. */
+ def getLocationPreference() : Option[String] = None
+
/**
* This method starts the receiver. First is accesses all the lazy members to
* materialize them. Then it calls the user-defined onStart() method to start
@@ -106,21 +110,23 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
actor ! ReportError(e.toString)
}
+
/**
* This method pushes a block (as iterator of values) into the block manager.
*/
- protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) {
+ def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) {
val buffer = new ArrayBuffer[T] ++ iterator
env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level)
- actor ! ReportBlock(blockId)
+
+ actor ! ReportBlock(blockId, metadata)
}
/**
* This method pushes a block (as bytes) into the block manager.
*/
- protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
- actor ! ReportBlock(blockId)
+ actor ! ReportBlock(blockId, metadata)
}
/** A helper actor that communicates with the NetworkInputTracker */
@@ -138,8 +144,8 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
}
override def receive() = {
- case ReportBlock(blockId) =>
- tracker ! AddBlocks(streamId, Array(blockId))
+ case ReportBlock(blockId, metadata) =>
+ tracker ! AddBlocks(streamId, Array(blockId), metadata)
case ReportError(msg) =>
tracker ! DeregisterReceiver(streamId, msg)
case StopReceiver(msg) =>
@@ -148,4 +154,3 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ
}
}
}
-
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
index ae6692290e..b421f795ee 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -13,7 +13,7 @@ import akka.dispatch._
trait NetworkInputTrackerMessage
case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
-case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage
+case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
@@ -22,7 +22,7 @@ class NetworkInputTracker(
@transient networkInputStreams: Array[NetworkInputDStream[_]])
extends Logging {
- val networkInputStreamIds = networkInputStreams.map(_.id).toArray
+ val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
val receivedBlockIds = new HashMap[Int, Queue[String]]
@@ -54,14 +54,14 @@ class NetworkInputTracker(
private class NetworkInputTrackerActor extends Actor {
def receive = {
case RegisterReceiver(streamId, receiverActor) => {
- if (!networkInputStreamIds.contains(streamId)) {
+ if (!networkInputStreamMap.contains(streamId)) {
throw new Exception("Register received for unexpected id " + streamId)
}
receiverInfo += ((streamId, receiverActor))
logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address)
sender ! true
}
- case AddBlocks(streamId, blockIds) => {
+ case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) {
receivedBlockIds += ((streamId, new Queue[String]))
@@ -71,6 +71,7 @@ class NetworkInputTracker(
tmp.synchronized {
tmp ++= blockIds
}
+ networkInputStreamMap(streamId).addMetadata(metadata)
}
case DeregisterReceiver(streamId, msg) => {
receiverInfo -= streamId
@@ -97,7 +98,18 @@ class NetworkInputTracker(
def startReceivers() {
val receivers = networkInputStreams.map(_.createReceiver())
- val tempRDD = ssc.sc.makeRDD(receivers, receivers.size)
+
+ // Right now, we only honor preferences if all receivers have them
+ val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _)
+
+ val tempRDD =
+ if (hasLocationPreferences) {
+ val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString)))
+ ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences)
+ }
+ else {
+ ssc.sc.makeRDD(receivers, receivers.size)
+ }
val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => {
if (!iterator.hasNext) {
diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
index 03726bfba6..6acaa9aab1 100644
--- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
@@ -31,6 +31,8 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S
var blockPushingThread: Thread = null
+ override def getLocationPreference = None
+
def onStart() {
// Open a socket to the target address and keep reading from it
logInfo("Connecting to " + host + ":" + port)
@@ -48,7 +50,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S
val buffer = queue.take()
val blockId = "input-" + streamId + "-" + nextBlockNumber
nextBlockNumber += 1
- pushBlock(blockId, buffer, storageLevel)
+ pushBlock(blockId, buffer, null, storageLevel)
}
}
}
diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
index b566200273..a9e37c0ff0 100644
--- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
@@ -32,7 +32,9 @@ class SocketReceiver[T: ClassManifest](
storageLevel: StorageLevel
) extends NetworkReceiver[T](streamId) {
- lazy protected val dataHandler = new DataHandler(this)
+ lazy protected val dataHandler = new DataHandler(this, storageLevel)
+
+ override def getLocationPreference = None
protected def onStart() {
logInfo("Connecting to " + host + ":" + port)
@@ -50,74 +52,6 @@ class SocketReceiver[T: ClassManifest](
dataHandler.stop()
}
- /**
- * This is a helper object that manages the data received from the socket. It divides
- * the object received into small batches of 100s of milliseconds, pushes them as
- * blocks into the block manager and reports the block IDs to the network input
- * tracker. It starts two threads, one to periodically start a new batch and prepare
- * the previous batch of as a block, the other to push the blocks into the block
- * manager.
- */
- class DataHandler(receiver: NetworkReceiver[T]) extends Serializable {
- case class Block(id: String, iterator: Iterator[T])
-
- val clock = new SystemClock()
- val blockInterval = 200L
- val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
- val blockStorageLevel = storageLevel
- val blocksForPushing = new ArrayBlockingQueue[Block](1000)
- val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
-
- var currentBuffer = new ArrayBuffer[T]
-
- def start() {
- blockIntervalTimer.start()
- blockPushingThread.start()
- logInfo("Data handler started")
- }
-
- def stop() {
- blockIntervalTimer.stop()
- blockPushingThread.interrupt()
- logInfo("Data handler stopped")
- }
-
- def += (obj: T) {
- currentBuffer += obj
- }
-
- def updateCurrentBuffer(time: Long) {
- try {
- val newBlockBuffer = currentBuffer
- currentBuffer = new ArrayBuffer[T]
- if (newBlockBuffer.size > 0) {
- val blockId = "input-" + streamId + "- " + (time - blockInterval)
- val newBlock = new Block(blockId, newBlockBuffer.toIterator)
- blocksForPushing.add(newBlock)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block interval timer thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
-
- def keepPushingBlocks() {
- logInfo("Block pushing thread started")
- try {
- while(true) {
- val block = blocksForPushing.take()
- pushBlock(block.id, block.iterator, storageLevel)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block pushing thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
- }
}
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 9c19f6588d..ce47bcb2da 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -15,6 +15,7 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+import org.apache.flume.source.avro.AvroFlumeEvent
import org.apache.hadoop.fs.Path
import java.util.UUID
import spark.util.MetadataCleaner
@@ -122,6 +123,31 @@ class StreamingContext private (
private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+ /**
+ * Create an input stream that pulls messages form a Kafka Broker.
+ *
+ * @param host Zookeper hostname.
+ * @param port Zookeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from zookeper.
+ * @param storageLevel RDD storage level. Defaults to memory-only.
+ */
+ def kafkaStream[T: ClassManifest](
+ hostname: String,
+ port: Int,
+ groupId: String,
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
+ ): DStream[T] = {
+ val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
def networkTextStream(
hostname: String,
port: Int,
@@ -141,6 +167,16 @@ class StreamingContext private (
inputStream
}
+ def flumeStream (
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2): DStream[SparkFlumeEvent] = {
+ val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
+
def rawNetworkStream[T: ClassManifest](
hostname: String,
port: Int,
diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
new file mode 100644
index 0000000000..e60ce483a3
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
@@ -0,0 +1,43 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.storage.StorageLevel
+import spark.streaming._
+
+/**
+ * Produce a streaming count of events received from Flume.
+ *
+ * This should be used in conjunction with an AvroSink in Flume. It will start
+ * an Avro server on at the request host:port address and listen for requests.
+ * Your Flume AvroSink should be pointed to this address.
+ *
+ * Usage: FlumeEventCount <master> <host> <port>
+ *
+ * <master> is a Spark master URL
+ * <host> is the host the Flume receiver will be started on - a receiver
+ * creates a server and listens for flume events.
+ * <port> is the port the Flume receiver will listen on.
+ */
+object FlumeEventCount {
+ def main(args: Array[String]) {
+ if (args.length != 3) {
+ System.err.println(
+ "Usage: FlumeEventCount <master> <host> <port>")
+ System.exit(1)
+ }
+
+ val Array(master, host, IntParam(port)) = args
+
+ val batchInterval = Milliseconds(2000)
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval)
+
+ // Create a flume stream
+ val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY)
+
+ // Print out the count of events received from this server in each batch
+ stream.count().map(cnt => "Received " + cnt + " flume events." ).print()
+
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
new file mode 100644
index 0000000000..fe55db6e2c
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -0,0 +1,69 @@
+package spark.streaming.examples
+
+import java.util.Properties
+import kafka.message.Message
+import kafka.producer.SyncProducerConfig
+import kafka.producer._
+import spark.SparkContext
+import spark.streaming._
+import spark.streaming.StreamingContext._
+import spark.storage.StorageLevel
+import spark.streaming.util.RawTextHelper._
+
+object KafkaWordCount {
+ def main(args: Array[String]) {
+
+ if (args.length < 6) {
+ System.err.println("Usage: KafkaWordCount <master> <hostname> <port> <group> <topics> <numThreads>")
+ System.exit(1)
+ }
+
+ val Array(master, hostname, port, group, topics, numThreads) = args
+
+ val sc = new SparkContext(master, "KafkaWordCount")
+ val ssc = new StreamingContext(sc, Seconds(2))
+ ssc.checkpoint("checkpoint")
+
+ val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
+ val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
+ wordCounts.print()
+
+ ssc.start()
+ }
+}
+
+// Produces some random words between 1 and 100.
+object KafkaWordCountProducer {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: KafkaWordCountProducer <hostname> <port> <topic> <messagesPerSec> <wordsPerMessage>")
+ System.exit(1)
+ }
+
+ val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args
+
+ // Zookeper connection properties
+ val props = new Properties()
+ props.put("zk.connect", hostname + ":" + port)
+ props.put("serializer.class", "kafka.serializer.StringEncoder")
+
+ val config = new ProducerConfig(props)
+ val producer = new Producer[String, String](config)
+
+ // Send some messages
+ while(true) {
+ val messages = (1 to messagesPerSec.toInt).map { messageNum =>
+ (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
+ }.toArray
+ println(messages.mkString(","))
+ val data = new ProducerData[String, String](topic, messages)
+ producer.send(data)
+ Thread.sleep(100)
+ }
+ }
+
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
new file mode 100644
index 0000000000..7c642d4802
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
@@ -0,0 +1,193 @@
+package spark.streaming
+
+import java.util.Properties
+import java.util.concurrent.Executors
+import kafka.consumer._
+import kafka.message.{Message, MessageSet, MessageAndMetadata}
+import kafka.serializer.StringDecoder
+import kafka.utils.{Utils, ZKGroupTopicDirs}
+import kafka.utils.ZkUtils._
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import spark._
+import spark.RDD
+import spark.storage.StorageLevel
+
+// Key for a specific Kafka Partition: (broker, topic, group, part)
+case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
+// NOT USED - Originally intended for fault-tolerance
+// Metadata for a Kafka Stream that it sent to the Master
+case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long])
+// NOT USED - Originally intended for fault-tolerance
+// Checkpoint data specific to a KafkaInputDstream
+case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
+ savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds)
+
+/**
+ * Input stream that pulls messages form a Kafka Broker.
+ *
+ * @param host Zookeper hostname.
+ * @param port Zookeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from zookeper.
+ * @param storageLevel RDD storage level.
+ */
+class KafkaInputDStream[T: ClassManifest](
+ @transient ssc_ : StreamingContext,
+ host: String,
+ port: Int,
+ groupId: String,
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long],
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[T](ssc_ ) with Logging {
+
+ // Metadata that keeps track of which messages have already been consumed.
+ var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]()
+
+ /* NOT USED - Originally intended for fault-tolerance
+
+ // In case of a failure, the offets for a particular timestamp will be restored.
+ @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null
+
+
+ override protected[streaming] def addMetadata(metadata: Any) {
+ metadata match {
+ case x : KafkaInputDStreamMetadata =>
+ savedOffsets(x.timestamp) = x.data
+ // TOOD: Remove logging
+ logInfo("New saved Offsets: " + savedOffsets)
+ case _ => logInfo("Received unknown metadata: " + metadata.toString)
+ }
+ }
+
+ override protected[streaming] def updateCheckpointData(currentTime: Time) {
+ super.updateCheckpointData(currentTime)
+ if(savedOffsets.size > 0) {
+ // Find the offets that were stored before the checkpoint was initiated
+ val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last
+ val latestOffsets = savedOffsets(key)
+ logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString)
+ checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets)
+ // TODO: This may throw out offsets that are created after the checkpoint,
+ // but it's unlikely we'll need them.
+ savedOffsets.clear()
+ }
+ }
+
+ override protected[streaming] def restoreCheckpointData() {
+ super.restoreCheckpointData()
+ logInfo("Restoring KafkaDStream checkpoint data.")
+ checkpointData match {
+ case x : KafkaDStreamCheckpointData =>
+ restoredOffsets = x.savedOffsets
+ logInfo("Restored KafkaDStream offsets: " + savedOffsets)
+ }
+ } */
+
+ def createReceiver(): NetworkReceiver[T] = {
+ new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel)
+ .asInstanceOf[NetworkReceiver[T]]
+ }
+}
+
+class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String,
+ topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
+ storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) {
+
+ // Timeout for establishing a connection to Zookeper in ms.
+ val ZK_TIMEOUT = 10000
+
+ // Handles pushing data into the BlockManager
+ lazy protected val dataHandler = new DataHandler(this, storageLevel)
+ // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset
+ lazy val offsets = HashMap[KafkaPartitionKey, Long]()
+ // Connection to Kafka
+ var consumerConnector : ZookeeperConsumerConnector = null
+
+ def onStop() {
+ dataHandler.stop()
+ }
+
+ def onStart() {
+
+ // Starting the DataHandler that buffers blocks and pushes them into them BlockManager
+ dataHandler.start()
+
+ // In case we are using multiple Threads to handle Kafka Messages
+ val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
+
+ val zooKeeperEndPoint = host + ":" + port
+ logInfo("Starting Kafka Consumer Stream with group: " + groupId)
+ logInfo("Initial offsets: " + initialOffsets.toString)
+
+ // Zookeper connection properties
+ val props = new Properties()
+ props.put("zk.connect", zooKeeperEndPoint)
+ props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
+ props.put("groupid", groupId)
+
+ // Create the connection to the cluster
+ logInfo("Connecting to Zookeper: " + zooKeeperEndPoint)
+ val consumerConfig = new ConsumerConfig(props)
+ consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
+ logInfo("Connected to " + zooKeeperEndPoint)
+
+ // Reset the Kafka offsets in case we are recovering from a failure
+ resetOffsets(initialOffsets)
+
+ // Create Threads for each Topic/Message Stream we are listening
+ val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
+
+ // Start the messages handler for each partition
+ topicMessageStreams.values.foreach { streams =>
+ streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) }
+ }
+
+ }
+
+ // Overwrites the offets in Zookeper.
+ private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) {
+ offsets.foreach { case(key, offset) =>
+ val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
+ val partitionName = key.brokerId + "-" + key.partId
+ updatePersistentPath(consumerConnector.zkClient,
+ topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
+ }
+ }
+
+ // Handles Kafka Messages
+ private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
+ def run() {
+ logInfo("Starting MessageHandler.")
+ stream.takeWhile { msgAndMetadata =>
+ dataHandler += msgAndMetadata.message
+
+ // Updating the offet. The key is (broker, topic, group, partition).
+ val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic,
+ groupId, msgAndMetadata.topicInfo.partition.partId)
+ val offset = msgAndMetadata.topicInfo.getConsumeOffset
+ offsets.put(key, offset)
+ // logInfo("Handled message: " + (key, offset).toString)
+
+ // Keep on handling messages
+ true
+ }
+ }
+ }
+
+ // NOT USED - Originally intended for fault-tolerance
+ // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel)
+ // extends DataHandler[Any](receiver, storageLevel) {
+
+ // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = {
+ // // Creates a new Block with Kafka-specific Metadata
+ // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap))
+ // }
+
+ // }
+
+}
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index b3afedf39f..0d82b2f1ea 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -63,9 +63,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// then check whether some RDD has been checkpointed or not
ssc.start()
runStreamsWithRealDelay(ssc, firstNumBatches)
- logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]")
- assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure")
- stateStream.checkpointData.foreach {
+ logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]")
+ assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure")
+ stateStream.checkpointData.rdds.foreach {
case (time, data) => {
val file = new File(data.toString)
assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist")
@@ -74,7 +74,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// Run till a further time such that previous checkpoint files in the stream would be deleted
// and check whether the earlier checkpoint files are deleted
- val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString))
+ val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString))
runStreamsWithRealDelay(ssc, secondNumBatches)
checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
ssc.stop()
@@ -91,8 +91,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// is present in the checkpoint data or not
ssc.start()
runStreamsWithRealDelay(ssc, 1)
- assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure")
- stateStream.checkpointData.foreach {
+ assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure")
+ stateStream.checkpointData.rdds.foreach {
case (time, data) => {
val file = new File(data.toString)
assert(file.exists(),
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index e98c096725..ed9a659092 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -1,6 +1,6 @@
package spark.streaming
-import java.net.{SocketException, Socket, ServerSocket}
+import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter}
import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
@@ -10,7 +10,14 @@ import spark.Logging
import scala.util.Random
import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter
-
+import org.apache.flume.source.avro.AvroSourceProtocol
+import org.apache.flume.source.avro.AvroFlumeEvent
+import org.apache.flume.source.avro.Status
+import org.apache.avro.ipc.{specific, NettyTransceiver}
+import org.apache.avro.ipc.specific.SpecificRequestor
+import java.nio.ByteBuffer
+import collection.JavaConversions._
+import java.nio.charset.Charset
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
@@ -123,6 +130,54 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
ssc.stop()
}
+ test("flume input stream") {
+ // Set up the streaming context and input streams
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val flumeStream = ssc.flumeStream("localhost", 33333, StorageLevel.MEMORY_AND_DISK)
+ val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]]
+ with SynchronizedBuffer[Seq[SparkFlumeEvent]]
+ val outputStream = new TestOutputStream(flumeStream, outputBuffer)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3, 4, 5)
+
+ val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 33333));
+ val client = SpecificRequestor.getClient(
+ classOf[AvroSourceProtocol], transceiver);
+
+ for (i <- 0 until input.size) {
+ val event = new AvroFlumeEvent
+ event.setBody(ByteBuffer.wrap(input(i).toString.getBytes()))
+ event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
+ client.append(event)
+ Thread.sleep(500)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+
+ val startTime = System.currentTimeMillis()
+ while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size)
+ Thread.sleep(100)
+ }
+ Thread.sleep(1000)
+ val timeTaken = System.currentTimeMillis() - startTime
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+ logInfo("Stopping context")
+ ssc.stop()
+
+ val decoder = Charset.forName("UTF-8").newDecoder()
+
+ assert(outputBuffer.size === input.length)
+ for (i <- 0 until outputBuffer.size) {
+ assert(outputBuffer(i).size === 1)
+ val str = decoder.decode(outputBuffer(i).head.event.getBody)
+ assert(str.toString === input(i).toString)
+ assert(outputBuffer(i).head.event.getHeaders.get("test") === "header")
+ }
+ }
+
test("file input stream") {
// Create a temporary directory