-rw-r--r--  external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala | 44
-rw-r--r--  external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala | 9
-rw-r--r--  external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java | 2
-rw-r--r--  external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java | 2
-rw-r--r--  external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java | 2
-rw-r--r--  external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala | 68
-rw-r--r--  project/MimaExcludes.scala | 4
7 files changed, 101 insertions(+), 30 deletions(-)
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 54d8c8b03f..0eaaf408c0 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -89,23 +89,32 @@ class DirectKafkaInputDStream[
private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
"spark.streaming.kafka.maxRatePerPartition", 0)
- protected def maxMessagesPerPartition: Option[Long] = {
+
+ protected[streaming] def maxMessagesPerPartition(
+ offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = {
val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
- val numPartitions = currentOffsets.keys.size
-
- val effectiveRateLimitPerPartition = estimatedRateLimit
- .filter(_ > 0)
- .map { limit =>
- if (maxRateLimitPerPartition > 0) {
- Math.min(maxRateLimitPerPartition, (limit / numPartitions))
- } else {
- limit / numPartitions
+
+ // calculate a per-partition rate limit based on current lag
+ val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
+ case Some(rate) =>
+ val lagPerPartition = offsets.map { case (tp, offset) =>
+ tp -> Math.max(offset - currentOffsets(tp), 0)
+ }
+ val totalLag = lagPerPartition.values.sum
+
+ lagPerPartition.map { case (tp, lag) =>
+ val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+ tp -> (if (maxRateLimitPerPartition > 0) {
+ Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
}
- }.getOrElse(maxRateLimitPerPartition)
+ case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition }
+ }
- if (effectiveRateLimitPerPartition > 0) {
+ if (effectiveRateLimitPerPartition.values.sum > 0) {
val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
- Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
+ Some(effectiveRateLimitPerPartition.map {
+ case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+ })
} else {
None
}
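
The new maxMessagesPerPartition splits the controller's estimated rate across partitions in proportion to each partition's current lag, then caps every share at spark.streaming.kafka.maxRatePerPartition when that limit is set (when no rate estimate is available, the patch falls back to the static per-partition limit). A minimal, self-contained sketch of the allocation branch, with illustrative names rather than the patch's API:

object LagProportionalLimit {
  /** Split `rate` across partitions in proportion to lag, capping at `maxRatePerPartition` (0 = no cap). */
  def perPartitionLimit(
      rate: Int,                          // latest estimate from the rate controller
      maxRatePerPartition: Int,           // spark.streaming.kafka.maxRatePerPartition
      currentOffsets: Map[Int, Long],     // partition -> offset already consumed
      latestOffsets: Map[Int, Long]): Map[Int, Long] = {
    val lag = latestOffsets.map { case (p, o) => p -> math.max(o - currentOffsets(p), 0L) }
    val totalLag = lag.values.sum
    lag.map { case (p, l) =>
      val share = math.round(l / totalLag.toFloat * rate)   // this partition's slice of the rate
      val capped = if (maxRatePerPartition > 0) math.min(share, maxRatePerPartition) else share
      p -> capped.toLong
    }
  }
}

A partition with twice the lag of another therefore gets twice the per-batch budget, which is what lets the stream drain lagging partitions preferentially.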
@@ -134,9 +143,12 @@ class DirectKafkaInputDStream[
// limits the maximum number of messages per partition
protected def clamp(
leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = {
- maxMessagesPerPartition.map { mmp =>
- leaderOffsets.map { case (tp, lo) =>
- tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset))
+ val offsets = leaderOffsets.mapValues(lo => lo.offset)
+
+ maxMessagesPerPartition(offsets).map { mmp =>
+ mmp.map { case (tp, messages) =>
+ val lo = leaderOffsets(tp)
+ tp -> lo.copy(offset = Math.min(currentOffsets(tp) + messages, lo.offset))
}
}.getOrElse(leaderOffsets)
}
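
The clamp step then turns those per-partition budgets into offset ranges: each leader offset is capped at the current offset plus the partition's budget, and never moves past the leader. A short sketch with illustrative names, assuming maxMessagesPerPartition returned a budget:

object ClampSketch {
  /** Cap each partition's target offset at currentOffset + budget, never past the leader. */
  def clamp(
      currentOffsets: Map[Int, Long],
      leaderOffsets: Map[Int, Long],
      budget: Map[Int, Long]): Map[Int, Long] =
    leaderOffsets.map { case (p, leader) =>
      p -> math.min(currentOffsets(p) + budget(p), leader)
    }
}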
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index a76fa6671a..a5ea1d6d28 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -152,12 +152,15 @@ private[kafka] class KafkaTestUtils extends Logging {
}
/** Create a Kafka topic and wait until it is propagated to the whole cluster */
- def createTopic(topic: String): Unit = {
- AdminUtils.createTopic(zkClient, topic, 1, 1)
+ def createTopic(topic: String, partitions: Int): Unit = {
+ AdminUtils.createTopic(zkClient, topic, partitions, 1)
// wait until metadata is propagated
- waitUntilMetadataIsPropagated(topic, 0)
+ (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) }
}
+ /** Single-argument version for backwards compatibility */
+ def createTopic(topic: String): Unit = createTopic(topic, 1)
+
/** Java-friendly function for sending messages to the Kafka broker */
def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*))
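
For reference, a hedged usage sketch of the changed test helper (only usable inside the org.apache.spark.streaming.kafka test package, since KafkaTestUtils is private[kafka]): the two-argument overload creates a multi-partition topic, while the new single-argument forwarder keeps existing call sites source-compatible.

val kafkaTestUtils = new KafkaTestUtils()
kafkaTestUtils.setup()
kafkaTestUtils.createTopic("backpressure", 2)                  // two partitions, replication factor 1
kafkaTestUtils.createTopic("legacy")                           // equivalent to createTopic("legacy", 1)
kafkaTestUtils.sendMessages("backpressure", Map("foo" -> 200)) // 200 copies of "foo"
kafkaTestUtils.teardown()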
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
index 4891e4f4a1..fa6b0dbc8c 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -168,7 +168,7 @@ public class JavaDirectKafkaStreamSuite implements Serializable {
private String[] createTopicAndSendData(String topic) {
String[] data = { topic + "-1", topic + "-2", topic + "-3"};
- kafkaTestUtils.createTopic(topic);
+ kafkaTestUtils.createTopic(topic, 1);
kafkaTestUtils.sendMessages(topic, data);
return data;
}
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index afcc6cfccd..c41b6297b0 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -149,7 +149,7 @@ public class JavaKafkaRDDSuite implements Serializable {
private String[] createTopicAndSendData(String topic) {
String[] data = { topic + "-1", topic + "-2", topic + "-3"};
- kafkaTestUtils.createTopic(topic);
+ kafkaTestUtils.createTopic(topic, 1);
kafkaTestUtils.sendMessages(topic, data);
return data;
}
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index 617c92a008..868df64e8c 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -76,7 +76,7 @@ public class JavaKafkaStreamSuite implements Serializable {
sent.put("b", 3);
sent.put("c", 10);
- kafkaTestUtils.createTopic(topic);
+ kafkaTestUtils.createTopic(topic, 1);
kafkaTestUtils.sendMessages(topic, sent);
Map<String, String> kafkaParams = new HashMap<>();
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 8398178e9b..b2c81d1534 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -353,10 +353,38 @@ class DirectKafkaStreamSuite
ssc.stop()
}
+ test("maxMessagesPerPartition with backpressure disabled") {
+ val topic = "maxMessagesPerPartition"
+ val kafkaStream = getDirectKafkaStream(topic, None)
+
+ val input = Map(TopicAndPartition(topic, 0) -> 50L, TopicAndPartition(topic, 1) -> 50L)
+ assert(kafkaStream.maxMessagesPerPartition(input).get ==
+ Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L))
+ }
+
+ test("maxMessagesPerPartition with no lag") {
+ val topic = "maxMessagesPerPartition"
+ val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100))
+ val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+ val input = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L)
+ assert(kafkaStream.maxMessagesPerPartition(input).isEmpty)
+ }
+
+ test("maxMessagesPerPartition respects max rate") {
+ val topic = "maxMessagesPerPartition"
+ val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000))
+ val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+ val input = Map(TopicAndPartition(topic, 0) -> 1000L, TopicAndPartition(topic, 1) -> 1000L)
+ assert(kafkaStream.maxMessagesPerPartition(input).get ==
+ Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L))
+ }
+
test("using rate controller") {
val topic = "backpressure"
- val topicPartition = TopicAndPartition(topic, 0)
- kafkaTestUtils.createTopic(topic)
+ val topicPartitions = Set(TopicAndPartition(topic, 0), TopicAndPartition(topic, 1))
+ kafkaTestUtils.createTopic(topic, 2)
val kafkaParams = Map(
"metadata.broker.list" -> kafkaTestUtils.brokerAddress,
"auto.offset.reset" -> "smallest"
@@ -364,8 +392,8 @@ class DirectKafkaStreamSuite
val batchIntervalMilliseconds = 100
val estimator = new ConstantEstimator(100)
- val messageKeys = (1 to 200).map(_.toString)
- val messages = messageKeys.map((_, 1)).toMap
+ val messages = Map("foo" -> 200)
+ kafkaTestUtils.sendMessages(topic, messages)
val sparkConf = new SparkConf()
// Safe, even with streaming, because we're using the direct API.
@@ -380,11 +408,11 @@ class DirectKafkaStreamSuite
val kafkaStream = withClue("Error creating direct stream") {
val kc = new KafkaCluster(kafkaParams)
val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
- val m = kc.getEarliestLeaderOffsets(Set(topicPartition))
+ val m = kc.getEarliestLeaderOffsets(topicPartitions)
.fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset))
new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
- ssc, kafkaParams, m, messageHandler) {
+ ssc, kafkaParams, m, messageHandler) {
override protected[streaming] val rateController =
Some(new DirectKafkaRateController(id, estimator))
}
@@ -405,13 +433,12 @@ class DirectKafkaStreamSuite
ssc.start()
// Try different rate limits.
- // Send data to Kafka and wait for arrays of data to appear matching the rate.
+ // Wait for arrays of data to appear matching the rate.
Seq(100, 50, 20).foreach { rate =>
collectedData.clear() // Empty this buffer on each pass.
estimator.updateRate(rate) // Set a new rate.
// Expect blocks of data equal to "rate", scaled by the interval length in secs.
val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001)
- kafkaTestUtils.sendMessages(topic, messages)
eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) {
// Assert that rate estimator values are used to determine maxMessagesPerPartition.
// Funky "-" in message makes the complete assertion message read better.
@@ -430,6 +457,25 @@ class DirectKafkaStreamSuite
rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges
}.toSeq.sortBy { _._1 }
}
+
+ private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = {
+ val batchIntervalMilliseconds = 100
+
+ val sparkConf = new SparkConf()
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+ val earliestOffsets = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L)
+ val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
+ new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
+ ssc, Map[String, String](), earliestOffsets, messageHandler) {
+ override protected[streaming] val rateController = mockRateController
+ }
+ }
}
object DirectKafkaStreamSuite {
@@ -468,3 +514,9 @@ private[streaming] class ConstantEstimator(@volatile private var rate: Long)
processingDelay: Long,
schedulingDelay: Long): Option[Double] = Some(rate)
}
+
+private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long)
+ extends RateController(id, estimator) {
+ override def publish(rate: Long): Unit = ()
+ override def getLatestRate(): Long = rate
+}
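
To see why the "maxMessagesPerPartition respects max rate" test above expects 10 messages per partition: the mock controller reports a rate of 1000, both partitions lag by 1000 messages, the per-partition cap is 100, and the batch interval is 100 ms. A small sketch of the arithmetic (not part of the patch):

object RespectsMaxRateArithmetic extends App {
  val rate = 1000                      // mock controller's latest rate
  val maxRatePerPartition = 100        // spark.streaming.kafka.maxRatePerPartition
  val lagPerPartition = 1000L          // each of the two partitions is 1000 messages behind
  val totalLag = 2 * lagPerPartition   // 2000
  val share = math.round(lagPerPartition / totalLag.toFloat * rate)  // 500 per partition
  val capped = math.min(share, maxRatePerPartition)                  // capped to 100
  val secsPerBatch = 0.1               // 100 ms batch interval
  val perBatch = (secsPerBatch * capped).toLong                      // 10 messages per partition
  assert(perBatch == 10L)
}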
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9ce37fc753..983f71684c 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -288,6 +288,10 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
+ ) ++ Seq(
+ // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
)
case v if v.startsWith("1.6") =>
Seq(