From 819be46e5a73f2d19230354ebba30c58538590f5 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Wed, 29 Jul 2015 13:47:37 -0700 Subject: [SPARK-8977] [STREAMING] Defines the RateEstimator interface, and impements the RateController MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on #7471. - [x] add a test that exercises the publish path from driver to receiver - [ ] remove Serializable from `RateController` and `RateEstimator` Author: Iulian Dragos Author: François Garillot Closes #7600 from dragos/topic/streaming-bp/rate-controller and squashes the following commits: f168c94 [Iulian Dragos] Latest review round. 5125e60 [Iulian Dragos] Fix style. a2eb3b9 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into topic/streaming-bp/rate-controller 475e346 [Iulian Dragos] Latest round of reviews. e9fb45e [Iulian Dragos] - Add a test for checkpointing - fixed serialization for RateController.executionContext 715437a [Iulian Dragos] Review comments and added a `reset` call in ReceiverTrackerTest. e57c66b [Iulian Dragos] Added a couple of tests for the full scenario from driver to receivers, with several rate updates. b425d32 [Iulian Dragos] Removed DeveloperAPI, removed rateEstimator field, removed Noop rate estimator, changed logic for initialising rate estimator. 238cfc6 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into topic/streaming-bp/rate-controller 34a389d [Iulian Dragos] Various style changes and a first test for the rate controller. d32ca36 [François Garillot] [SPARK-8977][Streaming] Defines the RateEstimator interface, and implements the ReceiverRateController 8941cf9 [Iulian Dragos] Renames and other nitpicks. 162d9e5 [Iulian Dragos] Use Reflection for accessing truly private `executor` method and use the listener bus to know when receivers have registered (`onStart` is called before receivers have registered, leading to flaky behavior). 210f495 [Iulian Dragos] Revert "Added a few tests that measure the receiver’s rate." 0c51959 [Iulian Dragos] Added a few tests that measure the receiver’s rate. 261a051 [Iulian Dragos] - removed field to hold the current rate limit in rate limiter - made rate limit a Long and default to Long.MaxValue (consequence of the above) - removed custom `waitUntil` and replaced it by `eventually` cd1397d [Iulian Dragos] Add a test for the propagation of a new rate limit from driver to receivers. 6369b30 [Iulian Dragos] Merge pull request #15 from huitseeker/SPARK-8975 d15de42 [François Garillot] [SPARK-8975][Streaming] Adds Ratelimiter unit tests w.r.t. spark.streaming.receiver.maxRate 4721c7d [François Garillot] [SPARK-8975][Streaming] Add a mechanism to send a new rate from the driver to the block generator --- .../spark/streaming/dstream/InputDStream.scala | 7 +- .../streaming/dstream/ReceiverInputDStream.scala | 26 +++++- .../spark/streaming/scheduler/JobScheduler.scala | 6 ++ .../spark/streaming/scheduler/RateController.scala | 90 ++++++++++++++++++ .../streaming/scheduler/rate/RateEstimator.scala | 59 ++++++++++++ .../apache/spark/streaming/CheckpointSuite.scala | 28 ++++++ .../streaming/scheduler/RateControllerSuite.scala | 103 +++++++++++++++++++++ .../scheduler/ReceiverSchedulingPolicySuite.scala | 10 +- .../streaming/scheduler/ReceiverTrackerSuite.scala | 41 ++++++-- 9 files changed, 355 insertions(+), 15 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala (limited to 'streaming') diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index d58c99a8ff..a6c4cd220e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.scheduler.RateController +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils /** @@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + // Keep track of the freshest rate for this stream using the rateEstimator + protected[streaming] val rateController: Option[RateController] = None + /** A human-readable name of this InputDStream */ private[streaming] def name: String = { // e.g. FlumePollingDStream -> "Flume polling stream" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a50f0efc03..646a8c3530 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,10 +21,11 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId -import org.apache.spark.streaming._ +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -40,6 +41,17 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) } + } else { + None + } + } + /** * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation @@ -110,4 +122,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } + + /** + * A RateController that sends the new rate to receivers, via the receiver tracker. + */ + private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = + ssc.scheduler.receiverTracker.sendRateUpdate(id, rate) + } } + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 4af9b6d3b5..58bdda7794 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -66,6 +66,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } eventLoop.start() + // attach rate controllers of input streams to receive batch completion updates + for { + inputDStream <- ssc.graph.getInputStreams + rateController <- inputDStream.rateController + } ssc.addStreamingListener(rateController) + listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala new file mode 100644 index 0000000000..882ca0676b --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.io.ObjectInputStream +import java.util.concurrent.atomic.AtomicLong + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A StreamingListener that receives batch completion updates, and maintains + * an estimate of the speed at which this stream should ingest messages, + * given an estimate computation from a `RateEstimator` + */ +private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { + + init() + + protected def publish(rate: Long): Unit + + @transient + implicit private var executionContext: ExecutionContext = _ + + @transient + private var rateLimit: AtomicLong = _ + + /** + * An initialization method called both from the constructor and Serialization code. + */ + private def init() { + executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")) + rateLimit = new AtomicLong(-1L) + } + + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + init() + } + + /** + * Compute the new rate limit and publish it asynchronously. + */ + private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = + Future[Unit] { + val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay) + newRate.foreach { s => + rateLimit.set(s.toLong) + publish(getLatestRate()) + } + } + + def getLatestRate(): Long = rateLimit.get() + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + val elements = batchCompleted.batchInfo.streamIdToInputInfo + + for { + processingEnd <- batchCompleted.batchInfo.processingEndTime; + workDelay <- batchCompleted.batchInfo.processingDelay; + waitDelay <- batchCompleted.batchInfo.schedulingDelay; + elems <- elements.get(streamUID).map(_.numRecords) + } computeAndPublish(processingEnd, elems, workDelay, waitDelay) + } +} + +object RateController { + def isBackPressureEnabled(conf: SparkConf): Boolean = + conf.getBoolean("spark.streaming.backpressure.enable", false) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala new file mode 100644 index 0000000000..a08685119e --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.SparkConf +import org.apache.spark.SparkException + +/** + * A component that estimates the rate at wich an InputDStream should ingest + * elements, based on updates at every batch completion. + */ +private[streaming] trait RateEstimator extends Serializable { + + /** + * Computes the number of elements the stream attached to this `RateEstimator` + * should ingest per second, given an update on the size and completion + * times of the latest batch. + * + * @param time The timetamp of the current batch interval that just finished + * @param elements The number of elements that were processed in this batch + * @param processingDelay The time in ms that took for the job to complete + * @param schedulingDelay The time in ms that the job spent in the scheduling queue + */ + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] +} + +object RateEstimator { + + /** + * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. + * + * @return None if there is no configured estimator, otherwise an instance of RateEstimator + * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any + * known estimators. + */ + def create(conf: SparkConf): Option[RateEstimator] = + conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator => + throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index d308ac05a5..67c2d90094 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -391,6 +393,32 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } + test("recovery maintains rate controller") { + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + SingletonTestRateReceiver.reset() + + val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) + output.register() + runStreams(ssc, 5, 5) + + SingletonTestRateReceiver.reset() + ssc = new StreamingContext(checkpointDir) + ssc.start() + val outputNew = advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(5.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + ssc.stop() + ssc = null + } + // This tests whether file input stream remembers what files were seen before // the master failure and uses them again to process a large window operation. // It also tests whether batches, whose processing was incomplete due to the diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala new file mode 100644 index 0000000000..921da773f6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +class RateControllerSuite extends TestSuiteBase { + + override def useManualClock: Boolean = false + + test("rate controller publishes updates") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) + dstream.register() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.publishCalls > 0) + } + } + } + + test("publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + dstream.register() + SingletonTestRateReceiver.reset() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + } + } + + test("multiple publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val rates = Seq(100L, 200L, 300L) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*))) + } + SingletonTestRateReceiver.reset() + dstream.register() + + val observedRates = mutable.HashSet.empty[Long] + ssc.start() + + eventually(timeout(20.seconds)) { + dstream.getCurrentRateLimit.foreach(observedRates += _) + // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver + observedRates should contain theSameElementsAs (rates :+ Long.MaxValue) + } + } + } +} + +private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator { + private var idx: Int = 0 + + private def nextRate(): Double = { + val rate = rates(idx) + idx = (idx + 1) % rates.size + rate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(nextRate()) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala index 93f920fdc7..0418d776ec 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -64,7 +64,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: " + "schedule receivers evenly when there are more receivers than executors") { - val receivers = (0 until 6).map(new DummyReceiver(_)) + val receivers = (0 until 6).map(new RateTestReceiver(_)) val executors = (10000 until 10003).map(port => s"localhost:${port}") val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) val numReceiversOnExecutor = mutable.HashMap[String, Int]() @@ -79,7 +79,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: " + "schedule receivers evenly when there are more executors than receivers") { - val receivers = (0 until 3).map(new DummyReceiver(_)) + val receivers = (0 until 3).map(new RateTestReceiver(_)) val executors = (10000 until 10006).map(port => s"localhost:${port}") val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) val numReceiversOnExecutor = mutable.HashMap[String, Int]() @@ -94,8 +94,8 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { } test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { - val receivers = (0 until 3).map(new DummyReceiver(_)) ++ - (3 until 6).map(new DummyReceiver(_, Some("localhost"))) + val receivers = (0 until 3).map(new RateTestReceiver(_)) ++ + (3 until 6).map(new RateTestReceiver(_, Some("localhost"))) val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ (10003 until 10006).map(port => s"localhost2:${port}") val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) @@ -121,7 +121,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { } test("scheduleReceivers: return empty scheduled executors if no executors") { - val receivers = (0 until 3).map(new DummyReceiver(_)) + val receivers = (0 until 3).map(new RateTestReceiver(_)) val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) scheduledExecutors.foreach { case (receiverId, executors) => assert(executors.isEmpty) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index b039233f36..aff8b53f75 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -43,6 +43,7 @@ class ReceiverTrackerSuite extends TestSuiteBase { ssc.addStreamingListener(ReceiverStartedWaiter) ssc.scheduler.listenerBus.start(ssc.sc) + SingletonTestRateReceiver.reset() val newRateLimit = 100L val inputDStream = new RateLimitInputDStream(ssc) @@ -62,36 +63,62 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } -/** An input DStream with a hard-coded receiver that gives access to internals for testing. */ -private class RateLimitInputDStream(@transient ssc_ : StreamingContext) +/** + * An input DStream with a hard-coded receiver that gives access to internals for testing. + * + * @note Make sure to call {{{SingletonDummyReceiver.reset()}}} before using this in a test, + * or otherwise you may get {{{NotSerializableException}}} when trying to serialize + * the receiver. + * @see [[[SingletonDummyReceiver]]]. + */ +private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext) extends ReceiverInputDStream[Int](ssc_) { - override def getReceiver(): DummyReceiver = SingletonDummyReceiver + override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver def getCurrentRateLimit: Option[Long] = { invokeExecutorMethod.getCurrentRateLimit } + @volatile + var publishCalls = 0 + + override val rateController: Option[RateController] = { + Some(new RateController(id, new ConstantEstimator(100.0)) { + override def publish(rate: Long): Unit = { + publishCalls += 1 + } + }) + } + private def invokeExecutorMethod: ReceiverSupervisor = { val c = classOf[Receiver[_]] val ex = c.getDeclaredMethod("executor") ex.setAccessible(true) - ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor] + ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor] } } /** - * A Receiver as an object so we can read its rate limit. + * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when + * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being + * serialized when receivers are installed on executors. * * @note It's necessary to be a top-level object, or else serialization would create another * one on the executor side and we won't be able to read its rate limit. */ -private object SingletonDummyReceiver extends DummyReceiver(0) +private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) { + + /** Reset the object to be usable in another test. */ + def reset(): Unit = { + executor_ = null + } +} /** * Dummy receiver implementation */ -private class DummyReceiver(receiverId: Int, host: Option[String] = None) +private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { setReceiverId(receiverId) -- cgit v1.2.3