-rw-r--r--  external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala   10
-rw-r--r--  external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala    59
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala       4
3 files changed, 67 insertions, 6 deletions
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 1b1fc8051d..6715aede79 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -17,7 +17,6 @@
package org.apache.spark.streaming.kafka
-
import scala.annotation.tailrec
import scala.collection.mutable
import scala.reflect.{classTag, ClassTag}
@@ -27,10 +26,10 @@ import kafka.message.MessageAndMetadata
import kafka.serializer.Decoder
import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
+import org.apache.spark.streaming.scheduler.InputInfo
/**
* A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where
@@ -117,6 +116,11 @@ class DirectKafkaInputDStream[
val rdd = KafkaRDD[K, V, U, T, R](
context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)
+ // Report the number of records in this batch interval to the InputInfoTracker.
+ val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum
+ val inputInfo = InputInfo(id, numRecords)
+ ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset)
Some(rdd)
}
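
Note (not part of the patch): a minimal, self-contained Scala sketch of the record-count arithmetic used above. Each offset range covers [fromOffset, untilOffset), so the batch's record count is the sum of the per-partition deltas; OffsetRange here is a simplified stand-in for the real Kafka OffsetRange class, and the sample values are illustrative only.

object RecordCountSketch {
  // Simplified stand-in for org.apache.spark.streaming.kafka.OffsetRange.
  final case class OffsetRange(topic: String, partition: Int, fromOffset: Long, untilOffset: Long)

  // Same computation as rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum above.
  def numRecords(ranges: Seq[OffsetRange]): Long =
    ranges.map(r => r.untilOffset - r.fromOffset).sum

  def main(args: Array[String]): Unit = {
    val ranges = Seq(
      OffsetRange("report-test", partition = 0, fromOffset = 0L, untilOffset = 7L),
      OffsetRange("report-test", partition = 1, fromOffset = 0L, untilOffset = 9L))
    // Prints 16, matching the 7 + 9 messages sent in the test below.
    println(numRecords(ranges))
  }
}
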
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 415730f555..b6d314dfc7 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.kafka
import java.io.File
+import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -34,6 +35,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.Utils
class DirectKafkaStreamSuite
@@ -290,7 +292,6 @@ class DirectKafkaStreamSuite
},
"Recovered ranges are not the same as the ones generated"
)
-
// Restart context, give more data and verify the total at the end
// If the total is right, that means each record has been received only once
ssc.start()
@@ -301,6 +302,44 @@ class DirectKafkaStreamSuite
ssc.stop()
}
+ test("Direct Kafka stream report input information") {
+ val topic = "report-test"
+ val data = Map("a" -> 7, "b" -> 9)
+ kafkaTestUtils.createTopic(topic)
+ kafkaTestUtils.sendMessages(topic, data)
+
+ val totalSent = data.values.sum
+ val kafkaParams = Map(
+ "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
+ "auto.offset.reset" -> "smallest"
+ )
+
+ import DirectKafkaStreamSuite._
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val collector = new InputInfoCollector
+ ssc.addStreamingListener(collector)
+
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, Set(topic))
+ }
+
+ val allReceived = new ArrayBuffer[(String, String)]
+
+ stream.foreachRDD { rdd => allReceived ++= rdd.collect() }
+ ssc.start()
+ eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+ assert(allReceived.size === totalSent,
+ "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n"))
+
+ // Verify the number of records collected by the StreamingListener.
+ assert(collector.numRecordsSubmitted.get() === totalSent)
+ assert(collector.numRecordsStarted.get() === totalSent)
+ assert(collector.numRecordsCompleted.get() === totalSent)
+ }
+ ssc.stop()
+ }
+
/** Get the generated offset ranges from the DirectKafkaStream */
private def getOffsetRanges[K, V](
kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
@@ -313,4 +352,22 @@ class DirectKafkaStreamSuite
object DirectKafkaStreamSuite {
val collectedData = new mutable.ArrayBuffer[String]()
var total = -1L
+
+ class InputInfoCollector extends StreamingListener {
+ val numRecordsSubmitted = new AtomicLong(0L)
+ val numRecordsStarted = new AtomicLong(0L)
+ val numRecordsCompleted = new AtomicLong(0L)
+
+ override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
+ numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords)
+ }
+
+ override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = {
+ numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords)
+ }
+
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+ numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords)
+ }
+ }
}
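
Note (not part of the patch): a hedged application-side sketch of how the same StreamingListener callbacks could be used to log the per-batch record counts that the direct Kafka stream now reports. It assumes an already-constructed StreamingContext named ssc and uses only the numRecords field exercised in the test above.

import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

class BatchRecordLogger extends StreamingListener {
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    // numRecords is the per-batch total seen once the direct Kafka stream
    // reports its InputInfo for the batch.
    println(s"batch completed with ${batchCompleted.batchInfo.numRecords} records")
  }
}

// Usage, mirroring the test above (ssc is an existing StreamingContext):
//   ssc.addStreamingListener(new BatchRecordLogger)
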
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index d2729fa70d..24cbb2bf9d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -192,8 +192,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
val latestReceiverNumRecords = latestBatchInfos.map(_.receiverNumRecords)
val streamIds = ssc.graph.getInputStreams().map(_.id)
streamIds.map { id =>
- val recordsOfParticularReceiver =
- latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration)
+ val recordsOfParticularReceiver =
+ latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration)
val distribution = Distribution(recordsOfParticularReceiver)
(id, distribution)
}.toMap
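
Note (not part of the patch): a self-contained sketch of the rate calculation in this listener, which converts a per-batch record count into a records-per-second figure by scaling with the batch duration in milliseconds. The sample numbers are illustrative, taken from the 16 messages and 200 ms batch interval used in the test above.

object ReceiverRateSketch {
  // Same scaling as v.getOrElse(id, 0L).toDouble * 1000 / batchDuration above.
  def recordsPerSecond(numRecords: Long, batchDurationMs: Long): Double =
    numRecords.toDouble * 1000 / batchDurationMs

  def main(args: Array[String]): Unit = {
    // 16 records in a 200 ms batch => 80.0 records per second.
    println(recordsPerSecond(16L, 200L))
  }
}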