aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorIulian Dragos <jaguarul@gmail.com>2015-07-31 12:04:03 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2015-07-31 12:04:03 -0700
commit0a1d2ca42c8b31d6b0e70163795f0185d4622f87 (patch)
treed453ae5039eeccfba70958dadab2c98766126ede /streaming
parente8bdcdeabb2df139a656f86686cdb53c891b1f4b (diff)
downloadspark-0a1d2ca42c8b31d6b0e70163795f0185d4622f87.tar.gz
spark-0a1d2ca42c8b31d6b0e70163795f0185d4622f87.tar.bz2
spark-0a1d2ca42c8b31d6b0e70163795f0185d4622f87.zip
[SPARK-8979] Add a PID based rate estimator
Based on #7600 /cc tdas Author: Iulian Dragos <jaguarul@gmail.com> Author: François Garillot <francois@garillot.net> Closes #7648 from dragos/topic/streaming-bp/pid and squashes the following commits: aa5b097 [Iulian Dragos] Add more comments, made all PID constant parameters positive, a couple more tests. 93b74f8 [Iulian Dragos] Better explanation of historicalError. 7975b0c [Iulian Dragos] Add configuration for PID. 26cfd78 [Iulian Dragos] A couple of variable renames. d0bdf7c [Iulian Dragos] Update to latest version of the code, various style and name improvements. d58b845 [François Garillot] [SPARK-8979][Streaming] Implements a PIDRateEstimator
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala124
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala18
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala137
4 files changed, 276 insertions, 5 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index 646a8c3530..670ef8d296 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -46,7 +46,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
*/
override protected[streaming] val rateController: Option[RateController] = {
if (RateController.isBackPressureEnabled(ssc.conf)) {
- RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) }
+ Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration)))
} else {
None
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
new file mode 100644
index 0000000000..6ae56a68ad
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler.rate
+
+/**
+ * Implements a proportional-integral-derivative (PID) controller which acts on
+ * the speed of ingestion of elements into Spark Streaming. A PID controller works
+ * by calculating an '''error''' between a measured output and a desired value. In the
+ * case of Spark Streaming the error is the difference between the measured processing
+ * rate (number of elements/processing delay) and the previous rate.
+ *
+ * @see https://en.wikipedia.org/wiki/PID_controller
+ *
+ * @param batchIntervalMillis the batch duration, in milliseconds
+ * @param proportional how much the correction should depend on the current
+ * error. This term usually provides the bulk of correction and should be positive or zero.
+ * A value too large would make the controller overshoot the setpoint, while a small value
+ * would make the controller too insensitive. The default value is 1.
+ * @param integral how much the correction should depend on the accumulation
+ * of past errors. This value should be positive or 0. This term accelerates the movement
+ * towards the desired value, but a large value may lead to overshooting. The default value
+ * is 0.2.
+ * @param derivative how much the correction should depend on a prediction
+ * of future errors, based on current rate of change. This value should be positive or 0.
+ * This term is not used very often, as it impacts stability of the system. The default
+ * value is 0.
+ */
+private[streaming] class PIDRateEstimator(
+    batchIntervalMillis: Long,
+    proportional: Double = 1D,
+    integral: Double = .2D,
+    derivative: Double = 0D)
+  extends RateEstimator {
+
+  private var firstRun: Boolean = true
+  private var latestTime: Long = -1L
+  private var latestRate: Double = -1D
+  private var latestError: Double = -1D
+
+  require(
+    batchIntervalMillis > 0,
+    s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.")
+  require(
+    proportional >= 0,
+    s"Proportional term $proportional in PIDRateEstimator should be >= 0.")
+  require(
+    integral >= 0,
+    s"Integral term $integral in PIDRateEstimator should be >= 0.")
+  require(
+    derivative >= 0,
+    s"Derivative term $derivative in PIDRateEstimator should be >= 0.")
+
+  /** Returns the new rate in elements/second; None on the first call or for unusable input. */
+  def compute(time: Long, // in milliseconds
+      numElements: Long,
+      processingDelay: Long, // in milliseconds
+      schedulingDelay: Long // in milliseconds
+    ): Option[Double] = {
+
+    this.synchronized {
+      if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) {
+
+        // in seconds, should be close to batchDuration
+        val delaySinceUpdate = (time - latestTime).toDouble / 1000
+
+        // in elements/second
+        val processingRate = numElements.toDouble / processingDelay * 1000
+
+        // In our system `error` is the difference between the desired rate and the measured rate
+        // based on the latest batch information. We consider the desired rate to be latest rate,
+        // which is what this estimator calculated for the previous batch.
+        // in elements/second
+        val error = latestRate - processingRate
+
+        // The error integral, based on schedulingDelay as an indicator for accumulated errors.
+        // A scheduling delay s corresponds to s * processingRate overflowing elements. Those
+        // are elements that couldn't be processed in previous batches, leading to this delay.
+        // In the following, we assume the processingRate didn't change too much.
+        // From the number of overflowing elements we can calculate the rate at which they would be
+        // processed by dividing it by the batch interval. This rate is our "historical" error,
+        // or integral part, since if we subtracted this rate from the previous "calculated rate",
+        // there wouldn't have been any overflowing elements, and the scheduling delay would have
+        // been zero.
+        // (in elements/second)
+        val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis
+
+        // in elements/(second ^ 2)
+        val dError = (error - latestError) / delaySinceUpdate
+
+        val newRate = (latestRate - proportional * error -
+                                    integral * historicalError -
+                                    derivative * dError).max(0.0)
+        latestTime = time
+        if (firstRun) {
+          latestRate = processingRate
+          latestError = 0D
+          firstRun = false
+
+          None
+        } else {
+          latestRate = newRate
+          latestError = error
+
+          Some(newRate)
+        }
+      } else None
+    }
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
index a08685119e..17ccebc1ed 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler.rate
import org.apache.spark.SparkConf
import org.apache.spark.SparkException
+import org.apache.spark.streaming.Duration
/**
* A component that estimates the rate at wich an InputDStream should ingest
@@ -48,12 +49,21 @@ object RateEstimator {
/**
* Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`.
*
- * @return None if there is no configured estimator, otherwise an instance of RateEstimator
+ * The only known estimator right now is `pid`.
+ *
+ * @return An instance of RateEstimator
* @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any
* known estimators.
*/
- def create(conf: SparkConf): Option[RateEstimator] =
- conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator =>
- throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
+  def create(conf: SparkConf, batchInterval: Duration): RateEstimator =
+    conf.get("spark.streaming.backpressure.rateEstimator", "pid") match {
+      case "pid" =>
+        val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0)
+        val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2)
+        val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0)
+        new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived)
+      // Any other configured name is rejected loudly instead of silently falling back to "pid".
+      case estimator =>
+        throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
     }
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
new file mode 100644
index 0000000000..97c32d8f2d
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler.rate
+
+import scala.util.Random
+
+import org.scalatest.Inspectors.forAll
+import org.scalatest.Matchers
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.streaming.Seconds
+
+class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
+  // Covers creation through RateEstimator.create, constructor validation, and compute().
+  test("the right estimator is created") {
+    val conf = new SparkConf
+    conf.set("spark.streaming.backpressure.rateEstimator", "pid")
+    val pid = RateEstimator.create(conf, Seconds(1))
+    pid.getClass should equal(classOf[PIDRateEstimator])
+  }
+
+  test("estimator checks ranges") {
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(0, 1, 2, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, -1, 2, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, -1, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, 0, -1)
+    }
+  }
+
+  private def createDefaultEstimator: PIDRateEstimator = {
+    new PIDRateEstimator(20, 1D, 0D, 0D)
+  }
+
+  test("first bound is None") {
+    val p = createDefaultEstimator
+    p.compute(0, 10, 10, 0) should equal(None)
+  }
+
+  test("second bound is rate") {
+    val p = createDefaultEstimator
+    p.compute(0, 10, 10, 0)
+    // 1000 elements / s
+    p.compute(10, 10, 10, 0) should equal(Some(1000))
+  }
+
+  test("works even with no time between updates") {
+    val p = createDefaultEstimator
+    p.compute(0, 10, 10, 0)
+    p.compute(10, 10, 10, 0)
+    p.compute(10, 10, 10, 0) should equal(None)
+  }
+
+  test("bound is never negative") {
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms, 0 processed elements, 20ms of processing
+    // this might point the estimator to try and decrease the bound, but we test it never
+    // goes below zero, which would be nonsensical.
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val elements = List.fill(50)(0) // no processing
+    val proc = List.fill(50)(20) // 20ms of processing
+    val sched = List.fill(50)(100) // strictly positive accumulation
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    res.tail should equal(List.fill(49)(Some(0D)))
+  }
+
+  test("with no accumulated or positive error, |I| > 0, follow the processing speed") {
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms with an increasing number of processed
+    // elements in each batch, but constant processing time, and no accumulated error. Even though
+    // the integral part is non-zero, the estimated rate should follow only the proportional term
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val elements = List.tabulate(50)(x => x * 20) // increasing
+    val proc = List.fill(50)(20) // 20ms of processing
+    val sched = List.fill(50)(0)
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail)
+  }
+
+  test("with no accumulated but some positive error, |I| > 0, follow the processing speed") {
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms with a decreasing number of processed
+    // elements in each batch, but constant processing time, and no accumulated error. Even though
+    // the integral part is non-zero, the estimated rate should follow only the proportional term,
+    // asking for less and less elements
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing
+    val proc = List.fill(50)(20) // 20ms of processing
+    val sched = List.fill(50)(0)
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail)
+  }
+
+  test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") {
+    val p = new PIDRateEstimator(20, 1D, .01D, 0D)
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val rng = new Random(42L) // fixed seed so that a failing run can be reproduced exactly
+    val elements = List.tabulate(50)(x => rng.nextInt(1000))
+    val procDelayMs = 20
+    val proc = List.fill(50)(procDelayMs) // 20ms of processing
+    val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait
+    val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000)
+
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    forAll(List.range(1, 50)) { (n) =>
+      res(n) should not be None
+      if (res(n).get > 0 && sched(n) > 0) {
+        res(n).get should be < speeds(n)
+      }
+    }
+  }
+}