Diffstat (limited to 'external/kafka/src/main/scala/org')
-rw-r--r--  external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala | 44
-rw-r--r--  external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala          |  9
2 files changed, 34 insertions, 19 deletions
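
Summary: DirectKafkaInputDStream.scala changes maxMessagesPerPartition to return a per-partition map of limits, weighting the backpressure rate by each partition's current lag instead of splitting the estimated rate evenly across partitions, and clamp is adjusted to apply those per-partition limits. KafkaTestUtils.scala gains a createTopic overload that takes a partition count, keeping the single-argument form for backwards compatibility.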
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 54d8c8b03f..0eaaf408c0 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -89,23 +89,32 @@ class DirectKafkaInputDStream[
private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
"spark.streaming.kafka.maxRatePerPartition", 0)
- protected def maxMessagesPerPartition: Option[Long] = {
+
+ protected[streaming] def maxMessagesPerPartition(
+ offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = {
val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
- val numPartitions = currentOffsets.keys.size
-
- val effectiveRateLimitPerPartition = estimatedRateLimit
- .filter(_ > 0)
- .map { limit =>
- if (maxRateLimitPerPartition > 0) {
- Math.min(maxRateLimitPerPartition, (limit / numPartitions))
- } else {
- limit / numPartitions
+
+ // calculate a per-partition rate limit based on current lag
+ val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
+ case Some(rate) =>
+ val lagPerPartition = offsets.map { case (tp, offset) =>
+ tp -> Math.max(offset - currentOffsets(tp), 0)
+ }
+ val totalLag = lagPerPartition.values.sum
+
+ lagPerPartition.map { case (tp, lag) =>
+ val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+ tp -> (if (maxRateLimitPerPartition > 0) {
+ Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
}
- }.getOrElse(maxRateLimitPerPartition)
+ case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition }
+ }
- if (effectiveRateLimitPerPartition > 0) {
+ if (effectiveRateLimitPerPartition.values.sum > 0) {
val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
- Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
+ Some(effectiveRateLimitPerPartition.map {
+ case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+ })
} else {
None
}
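
The arithmetic in the new maxMessagesPerPartition is easiest to see in isolation. The sketch below (illustrative names, not the Spark API) reproduces the allocation rule from the hunk above: each partition gets a share of the controller's estimated rate proportional to its share of the total lag, optionally capped by spark.streaming.kafka.maxRatePerPartition.

object LagProportionalAllocationSketch {
  // partitionLag: how far each partition's leader offset is ahead of our
  // current offset; rate: total messages/sec estimated by the rate controller;
  // maxPerPartition: 0 means "no per-partition cap".
  def allocate(
      partitionLag: Map[String, Long],
      rate: Int,
      maxPerPartition: Int): Map[String, Long] = {
    val totalLag = partitionLag.values.sum
    partitionLag.map { case (tp, lag) =>
      val backpressureRate = Math.round(lag / totalLag.toFloat * rate).toLong
      tp -> (if (maxPerPartition > 0) Math.min(backpressureRate, maxPerPartition.toLong)
             else backpressureRate)
    }
  }

  def main(args: Array[String]): Unit = {
    // "t-0" is three times further behind than "t-1", so it gets ~3x the rate.
    println(allocate(Map("t-0" -> 300L, "t-1" -> 100L), rate = 1000, maxPerPartition = 0))
    // prints: Map(t-0 -> 750, t-1 -> 250)
  }
}

One edge worth noting: when total lag is zero, the division yields NaN, Math.round turns that into 0, and the .values.sum > 0 guard in the patched method then returns None, i.e. no rate limiting is applied when no partition is behind.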
@@ -134,9 +143,12 @@ class DirectKafkaInputDStream[
// limits the maximum number of messages per partition
protected def clamp(
leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = {
- maxMessagesPerPartition.map { mmp =>
- leaderOffsets.map { case (tp, lo) =>
- tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset))
+ val offsets = leaderOffsets.mapValues(lo => lo.offset)
+
+ maxMessagesPerPartition(offsets).map { mmp =>
+ mmp.map { case (tp, messages) =>
+ val lo = leaderOffsets(tp)
+ tp -> lo.copy(offset = Math.min(currentOffsets(tp) + messages, lo.offset))
}
}.getOrElse(leaderOffsets)
}
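
The reworked clamp then combines those per-partition allowances with the leader offsets. A minimal script-style sketch of that step, with plain Longs standing in for LeaderOffset:

val currentOffset = 100L // where this partition's consumer currently is
val leaderOffset  = 500L // latest offset available on the Kafka leader
val allowed       = 250L // messages this batch may read, from maxMessagesPerPartition

// Never read past the leader, and never more than the allowance.
val untilOffset = Math.min(currentOffset + allowed, leaderOffset)
assert(untilOffset == 350L) // the rate limit binds here, not the leader offset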
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index a76fa6671a..a5ea1d6d28 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -152,12 +152,15 @@ private[kafka] class KafkaTestUtils extends Logging {
}
/** Create a Kafka topic and wait until it is propagated to the whole cluster */
- def createTopic(topic: String): Unit = {
- AdminUtils.createTopic(zkClient, topic, 1, 1)
+ def createTopic(topic: String, partitions: Int): Unit = {
+ AdminUtils.createTopic(zkClient, topic, partitions, 1)
// wait until metadata is propagated
- waitUntilMetadataIsPropagated(topic, 0)
+ (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) }
}
+ /** Single-argument version for backwards compatibility */
+ def createTopic(topic: String): Unit = createTopic(topic, 1)
+
/** Java-friendly function for sending messages to the Kafka broker */
def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*))
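
For the test-utility change, a hypothetical call site (setup() and teardown() are existing methods on KafkaTestUtils; the topic names are made up):

// assumes code in the org.apache.spark.streaming.kafka package,
// since the class is private[kafka]
val kafkaTestUtils = new KafkaTestUtils
kafkaTestUtils.setup()
kafkaTestUtils.createTopic("backpressure-test", 4) // new multi-partition overload
kafkaTestUtils.createTopic("legacy-test")          // old signature still works
kafkaTestUtils.teardown()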