aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorIulian Dragos <jaguarul@gmail.com>2015-07-29 13:47:37 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2015-07-29 13:47:37 -0700
commit819be46e5a73f2d19230354ebba30c58538590f5 (patch)
treebbbfc1016fe8bdb553864cea1bb036bdc27188ae /streaming
parent069a4c414db4612d7bdb6f5615c1ba36998e5a49 (diff)
downloadspark-819be46e5a73f2d19230354ebba30c58538590f5.tar.gz
spark-819be46e5a73f2d19230354ebba30c58538590f5.tar.bz2
spark-819be46e5a73f2d19230354ebba30c58538590f5.zip
[SPARK-8977] [STREAMING] Defines the RateEstimator interface, and implements the RateController
Based on #7471. - [x] add a test that exercises the publish path from driver to receiver - [ ] remove Serializable from `RateController` and `RateEstimator` Author: Iulian Dragos <jaguarul@gmail.com> Author: François Garillot <francois@garillot.net> Closes #7600 from dragos/topic/streaming-bp/rate-controller and squashes the following commits: f168c94 [Iulian Dragos] Latest review round. 5125e60 [Iulian Dragos] Fix style. a2eb3b9 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into topic/streaming-bp/rate-controller 475e346 [Iulian Dragos] Latest round of reviews. e9fb45e [Iulian Dragos] - Add a test for checkpointing - fixed serialization for RateController.executionContext 715437a [Iulian Dragos] Review comments and added a `reset` call in ReceiverTrackerTest. e57c66b [Iulian Dragos] Added a couple of tests for the full scenario from driver to receivers, with several rate updates. b425d32 [Iulian Dragos] Removed DeveloperAPI, removed rateEstimator field, removed Noop rate estimator, changed logic for initialising rate estimator. 238cfc6 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into topic/streaming-bp/rate-controller 34a389d [Iulian Dragos] Various style changes and a first test for the rate controller. d32ca36 [François Garillot] [SPARK-8977][Streaming] Defines the RateEstimator interface, and implements the ReceiverRateController 8941cf9 [Iulian Dragos] Renames and other nitpicks. 162d9e5 [Iulian Dragos] Use Reflection for accessing truly private `executor` method and use the listener bus to know when receivers have registered (`onStart` is called before receivers have registered, leading to flaky behavior). 210f495 [Iulian Dragos] Revert "Added a few tests that measure the receiver’s rate." 0c51959 [Iulian Dragos] Added a few tests that measure the receiver’s rate. 
261a051 [Iulian Dragos] - removed field to hold the current rate limit in rate limiter - made rate limit a Long and default to Long.MaxValue (consequence of the above) - removed custom `waitUntil` and replaced it by `eventually` cd1397d [Iulian Dragos] Add a test for the propagation of a new rate limit from driver to receivers. 6369b30 [Iulian Dragos] Merge pull request #15 from huitseeker/SPARK-8975 d15de42 [François Garillot] [SPARK-8975][Streaming] Adds Ratelimiter unit tests w.r.t. spark.streaming.receiver.maxRate 4721c7d [François Garillot] [SPARK-8975][Streaming] Add a mechanism to send a new rate from the driver to the block generator
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala26
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala6
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala90
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala59
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala28
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala103
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala10
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala41
9 files changed, 355 insertions, 15 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index d58c99a8ff..a6c4cd220e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -21,7 +21,9 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDDOperationScope
-import org.apache.spark.streaming.{Time, Duration, StreamingContext}
+import org.apache.spark.streaming.{Duration, StreamingContext, Time}
+import org.apache.spark.streaming.scheduler.RateController
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.Utils
/**
@@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
/** This is an unique identifier for the input stream. */
val id = ssc.getNewInputStreamId()
+ // Keep track of the freshest rate for this stream using the rateEstimator
+ protected[streaming] val rateController: Option[RateController] = None
+
/** A human-readable name of this InputDStream */
private[streaming] def name: String = {
// e.g. FlumePollingDStream -> "Flume polling stream"
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index a50f0efc03..646a8c3530 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -21,10 +21,11 @@ import scala.reflect.ClassTag
import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.BlockId
-import org.apache.spark.streaming._
+import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
import org.apache.spark.streaming.receiver.Receiver
-import org.apache.spark.streaming.scheduler.StreamInputInfo
+import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.streaming.util.WriteAheadLogUtils
/**
@@ -41,6 +42,17 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
extends InputDStream[T](ssc_) {
/**
+ * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
+ */
+ override protected[streaming] val rateController: Option[RateController] = {
+ if (RateController.isBackPressureEnabled(ssc.conf)) {
+ RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) }
+ } else {
+ None
+ }
+ }
+
+ /**
* Gets the receiver object that will be sent to the worker nodes
* to receive data. This method needs to defined by any specific implementation
* of a ReceiverInputDStream.
@@ -110,4 +122,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
}
Some(blockRDD)
}
+
+ /**
+ * A RateController that sends the new rate to receivers, via the receiver tracker.
+ */
+ private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator)
+ extends RateController(id, estimator) {
+ override def publish(rate: Long): Unit =
+ ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
+ }
}
+
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 4af9b6d3b5..58bdda7794 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -66,6 +66,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
}
eventLoop.start()
+ // attach rate controllers of input streams to receive batch completion updates
+ for {
+ inputDStream <- ssc.graph.getInputStreams
+ rateController <- inputDStream.rateController
+ } ssc.addStreamingListener(rateController)
+
listenerBus.start(ssc.sparkContext)
receiverTracker = new ReceiverTracker(ssc)
inputInfoTracker = new InputInfoTracker(ssc)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
new file mode 100644
index 0000000000..882ca0676b
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import java.io.ObjectInputStream
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.concurrent.{ExecutionContext, Future}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * A StreamingListener that receives batch completion updates, and maintains
+ * an estimate of the speed at which this stream should ingest messages,
+ * given an estimate computation from a `RateEstimator`
+ */
+private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator)
+ extends StreamingListener with Serializable {
+
+ init()
+
+ protected def publish(rate: Long): Unit
+
+ @transient
+ implicit private var executionContext: ExecutionContext = _
+
+ @transient
+ private var rateLimit: AtomicLong = _
+
+ /**
+ * An initialization method called both from the constructor and Serialization code.
+ */
+ private def init() {
+ executionContext = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update"))
+ rateLimit = new AtomicLong(-1L)
+ }
+
+ private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
+ ois.defaultReadObject()
+ init()
+ }
+
+ /**
+ * Compute the new rate limit and publish it asynchronously.
+ */
+ private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
+ Future[Unit] {
+ val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay)
+ newRate.foreach { s =>
+ rateLimit.set(s.toLong)
+ publish(getLatestRate())
+ }
+ }
+
+ def getLatestRate(): Long = rateLimit.get()
+
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ val elements = batchCompleted.batchInfo.streamIdToInputInfo
+
+ for {
+ processingEnd <- batchCompleted.batchInfo.processingEndTime;
+ workDelay <- batchCompleted.batchInfo.processingDelay;
+ waitDelay <- batchCompleted.batchInfo.schedulingDelay;
+ elems <- elements.get(streamUID).map(_.numRecords)
+ } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
+ }
+}
+
+object RateController {
+ def isBackPressureEnabled(conf: SparkConf): Boolean =
+ conf.getBoolean("spark.streaming.backpressure.enable", false)
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
new file mode 100644
index 0000000000..a08685119e
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler.rate
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkException
+
+/**
+ * A component that estimates the rate at which an InputDStream should ingest
+ * elements, based on updates at every batch completion.
+ */
+private[streaming] trait RateEstimator extends Serializable {
+
+ /**
+ * Computes the number of elements the stream attached to this `RateEstimator`
+ * should ingest per second, given an update on the size and completion
+ * times of the latest batch.
+ *
+ * @param time The timestamp of the current batch interval that just finished
+ * @param elements The number of elements that were processed in this batch
+ * @param processingDelay The time in ms that it took for the job to complete
+ * @param schedulingDelay The time in ms that the job spent in the scheduling queue
+ */
+ def compute(
+ time: Long,
+ elements: Long,
+ processingDelay: Long,
+ schedulingDelay: Long): Option[Double]
+}
+
+object RateEstimator {
+
+ /**
+ * Return a new RateEstimator based on the value of `spark.streaming.backpressure.rateEstimator`.
+ *
+ * @return None if there is no configured estimator, otherwise an instance of RateEstimator
+ * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any
+ * known estimators.
+ */
+ def create(conf: SparkConf): Option[RateEstimator] =
+ conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator =>
+ throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index d308ac05a5..67c2d90094 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver}
import org.apache.spark.util.{Clock, ManualClock, Utils}
/**
@@ -391,6 +393,32 @@ class CheckpointSuite extends TestSuiteBase {
testCheckpointedOperation(input, operation, output, 7)
}
+ test("recovery maintains rate controller") {
+ ssc = new StreamingContext(conf, batchDuration)
+ ssc.checkpoint(checkpointDir)
+
+ val dstream = new RateLimitInputDStream(ssc) {
+ override val rateController =
+ Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+ }
+ SingletonTestRateReceiver.reset()
+
+ val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2))
+ output.register()
+ runStreams(ssc, 5, 5)
+
+ SingletonTestRateReceiver.reset()
+ ssc = new StreamingContext(checkpointDir)
+ ssc.start()
+ val outputNew = advanceTimeWithRealDelay(ssc, 2)
+
+ eventually(timeout(5.seconds)) {
+ assert(dstream.getCurrentRateLimit === Some(200))
+ }
+ ssc.stop()
+ ssc = null
+ }
+
// This tests whether file input stream remembers what files were seen before
// the master failure and uses them again to process a large window operation.
// It also tests whether batches, whose processing was incomplete due to the
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
new file mode 100644
index 0000000000..921da773f6
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+
+import org.scalatest.Matchers._
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+
+class RateControllerSuite extends TestSuiteBase {
+
+ override def useManualClock: Boolean = false
+
+ test("rate controller publishes updates") {
+ val ssc = new StreamingContext(conf, batchDuration)
+ withStreamingContext(ssc) { ssc =>
+ val dstream = new RateLimitInputDStream(ssc)
+ dstream.register()
+ ssc.start()
+
+ eventually(timeout(10.seconds)) {
+ assert(dstream.publishCalls > 0)
+ }
+ }
+ }
+
+ test("publish rates reach receivers") {
+ val ssc = new StreamingContext(conf, batchDuration)
+ withStreamingContext(ssc) { ssc =>
+ val dstream = new RateLimitInputDStream(ssc) {
+ override val rateController =
+ Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+ }
+ dstream.register()
+ SingletonTestRateReceiver.reset()
+ ssc.start()
+
+ eventually(timeout(10.seconds)) {
+ assert(dstream.getCurrentRateLimit === Some(200))
+ }
+ }
+ }
+
+ test("multiple publish rates reach receivers") {
+ val ssc = new StreamingContext(conf, batchDuration)
+ withStreamingContext(ssc) { ssc =>
+ val rates = Seq(100L, 200L, 300L)
+
+ val dstream = new RateLimitInputDStream(ssc) {
+ override val rateController =
+ Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*)))
+ }
+ SingletonTestRateReceiver.reset()
+ dstream.register()
+
+ val observedRates = mutable.HashSet.empty[Long]
+ ssc.start()
+
+ eventually(timeout(20.seconds)) {
+ dstream.getCurrentRateLimit.foreach(observedRates += _)
+ // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver
+ observedRates should contain theSameElementsAs (rates :+ Long.MaxValue)
+ }
+ }
+ }
+}
+
+private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator {
+ private var idx: Int = 0
+
+ private def nextRate(): Double = {
+ val rate = rates(idx)
+ idx = (idx + 1) % rates.size
+ rate
+ }
+
+ def compute(
+ time: Long,
+ elements: Long,
+ processingDelay: Long,
+ schedulingDelay: Long): Option[Double] = Some(nextRate())
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
index 93f920fdc7..0418d776ec 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
@@ -64,7 +64,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
test("scheduleReceivers: " +
"schedule receivers evenly when there are more receivers than executors") {
- val receivers = (0 until 6).map(new DummyReceiver(_))
+ val receivers = (0 until 6).map(new RateTestReceiver(_))
val executors = (10000 until 10003).map(port => s"localhost:${port}")
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
val numReceiversOnExecutor = mutable.HashMap[String, Int]()
@@ -79,7 +79,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
test("scheduleReceivers: " +
"schedule receivers evenly when there are more executors than receivers") {
- val receivers = (0 until 3).map(new DummyReceiver(_))
+ val receivers = (0 until 3).map(new RateTestReceiver(_))
val executors = (10000 until 10006).map(port => s"localhost:${port}")
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
val numReceiversOnExecutor = mutable.HashMap[String, Int]()
@@ -94,8 +94,8 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
}
test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") {
- val receivers = (0 until 3).map(new DummyReceiver(_)) ++
- (3 until 6).map(new DummyReceiver(_, Some("localhost")))
+ val receivers = (0 until 3).map(new RateTestReceiver(_)) ++
+ (3 until 6).map(new RateTestReceiver(_, Some("localhost")))
val executors = (10000 until 10003).map(port => s"localhost:${port}") ++
(10003 until 10006).map(port => s"localhost2:${port}")
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
@@ -121,7 +121,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
}
test("scheduleReceivers: return empty scheduled executors if no executors") {
- val receivers = (0 until 3).map(new DummyReceiver(_))
+ val receivers = (0 until 3).map(new RateTestReceiver(_))
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty)
scheduledExecutors.foreach { case (receiverId, executors) =>
assert(executors.isEmpty)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index b039233f36..aff8b53f75 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -43,6 +43,7 @@ class ReceiverTrackerSuite extends TestSuiteBase {
ssc.addStreamingListener(ReceiverStartedWaiter)
ssc.scheduler.listenerBus.start(ssc.sc)
+ SingletonTestRateReceiver.reset()
val newRateLimit = 100L
val inputDStream = new RateLimitInputDStream(ssc)
@@ -62,36 +63,62 @@ class ReceiverTrackerSuite extends TestSuiteBase {
}
}
-/** An input DStream with a hard-coded receiver that gives access to internals for testing. */
-private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
+/**
+ * An input DStream with a hard-coded receiver that gives access to internals for testing.
+ *
+ * @note Make sure to call {{{SingletonTestRateReceiver.reset()}}} before using this in a test,
+ * or otherwise you may get {{{NotSerializableException}}} when trying to serialize
+ * the receiver.
+ * @see [[SingletonTestRateReceiver]].
+ */
+private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext)
extends ReceiverInputDStream[Int](ssc_) {
- override def getReceiver(): DummyReceiver = SingletonDummyReceiver
+ override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver
def getCurrentRateLimit: Option[Long] = {
invokeExecutorMethod.getCurrentRateLimit
}
+ @volatile
+ var publishCalls = 0
+
+ override val rateController: Option[RateController] = {
+ Some(new RateController(id, new ConstantEstimator(100.0)) {
+ override def publish(rate: Long): Unit = {
+ publishCalls += 1
+ }
+ })
+ }
+
private def invokeExecutorMethod: ReceiverSupervisor = {
val c = classOf[Receiver[_]]
val ex = c.getDeclaredMethod("executor")
ex.setAccessible(true)
- ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor]
+ ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor]
}
}
/**
- * A Receiver as an object so we can read its rate limit.
+ * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when
+ * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being
+ * serialized when receivers are installed on executors.
*
* @note It's necessary to be a top-level object, or else serialization would create another
* one on the executor side and we won't be able to read its rate limit.
*/
-private object SingletonDummyReceiver extends DummyReceiver(0)
+private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) {
+
+ /** Reset the object to be usable in another test. */
+ def reset(): Unit = {
+ executor_ = null
+ }
+}
/**
* Dummy receiver implementation
*/
-private class DummyReceiver(receiverId: Int, host: Option[String] = None)
+private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
setReceiverId(receiverId)