aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2015-08-06 14:35:30 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2015-08-06 14:35:30 -0700
commit0a078303d08ad2bb92b9a8a6969563d75b512290 (patch)
tree3dd33e8c34634c6797da37561e230faaadf2b395 /streaming
parent21fdfd7d6f89adbd37066c169e6ba9ccd337683e (diff)
downloadspark-0a078303d08ad2bb92b9a8a6969563d75b512290.tar.gz
spark-0a078303d08ad2bb92b9a8a6969563d75b512290.tar.bz2
spark-0a078303d08ad2bb92b9a8a6969563d75b512290.zip
[SPARK-9556] [SPARK-9619] [SPARK-9624] [STREAMING] Make BlockGenerator more robust and make all BlockGenerators subscribe to rate limit updates
In some receivers, instead of using the default `BlockGenerator` in `ReceiverSupervisorImpl`, custom generator with their custom listeners are used for reliability (see [`ReliableKafkaReceiver`](https://github.com/apache/spark/blob/master/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala#L99) and [updated `KinesisReceiver`](https://github.com/apache/spark/pull/7825/files)). These custom generators do not receive rate updates. This PR modifies the code to allow custom `BlockGenerator`s to be created through the `ReceiverSupervisorImpl` so that they can be kept track and rate updates can be applied. In the process, I did some simplification, and de-flaki-fication of some rate controller related tests. In particular. - Renamed `Receiver.executor` to `Receiver.supervisor` (to match `ReceiverSupervisor`) - Made `RateControllerSuite` faster (by increasing batch interval) and less flaky - Changed a few internal API to return the current rate of block generators as Long instead of Option\[Long\] (was inconsistent at places). - Updated existing `ReceiverTrackerSuite` to test that custom block generators get rate updates as well. Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #7913 from tdas/SPARK-9556 and squashes the following commits: 41d4461 [Tathagata Das] fix scala style eb9fd59 [Tathagata Das] Updated kinesis receiver d24994d [Tathagata Das] Updated BlockGeneratorSuite to use manual clock in BlockGenerator d70608b [Tathagata Das] Updated BlockGenerator with states and proper synchronization f6bd47e [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9556 31da173 [Tathagata Das] Fix bug 12116df [Tathagata Das] Add BlockGeneratorSuite 74bd069 [Tathagata Das] Fix style 989bb5c [Tathagata Das] Made BlockGenerator fail is used after stop, and added better unit tests for it 3ff618c [Tathagata Das] Fix test b40eff8 [Tathagata Das] slight refactoring f0df0f1 [Tathagata Das] Scala style fixes 51759cb [Tathagata Das] Refactored rate controller tests and added the ability to update rate of any custom block generator
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala8
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala131
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala52
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala27
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala33
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala16
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala31
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala253
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala64
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala129
11 files changed, 531 insertions, 216 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
index cd309788a7..7ec74016a1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
@@ -144,7 +144,7 @@ private[streaming] class ActorReceiver[T: ClassTag](
receiverSupervisorStrategy: SupervisorStrategy
) extends Receiver[T](storageLevel) with Logging {
- protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor),
+ protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor),
"Supervisor" + streamId)
class Supervisor extends Actor {
@@ -191,11 +191,11 @@ private[streaming] class ActorReceiver[T: ClassTag](
}
def onStart(): Unit = {
- supervisor
- logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
+ actorSupervisor
+ logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path)
}
def onStop(): Unit = {
- supervisor ! PoisonPill
+ actorSupervisor ! PoisonPill
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index 92b51ce392..794dece370 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -21,10 +21,10 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.util.RecurringTimer
-import org.apache.spark.util.SystemClock
+import org.apache.spark.util.{Clock, SystemClock}
/** Listener object for BlockGenerator events */
private[streaming] trait BlockGeneratorListener {
@@ -69,16 +69,35 @@ private[streaming] trait BlockGeneratorListener {
* named blocks at regular intervals. This class starts two threads,
* one to periodically start a new batch and prepare the previous batch of as a block,
* the other to push the blocks into the block manager.
+ *
+ * Note: Do not create BlockGenerator instances directly inside receivers. Use
+ * `ReceiverSupervisor.createBlockGenerator` to create a BlockGenerator and use it.
*/
private[streaming] class BlockGenerator(
listener: BlockGeneratorListener,
receiverId: Int,
- conf: SparkConf
+ conf: SparkConf,
+ clock: Clock = new SystemClock()
) extends RateLimiter(conf) with Logging {
private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any])
- private val clock = new SystemClock()
+ /**
+ * The BlockGenerator can be in 5 possible states, in the order as follows.
+ * - Initialized: Nothing has been started
+ * - Active: start() has been called, and it is generating blocks on added data.
+ * - StoppedAddingData: stop() has been called, the adding of data has been stopped,
+ * but blocks are still being generated and pushed.
+ * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but
+ * they are still being pushed.
+ * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed.
+ */
+ private object GeneratorState extends Enumeration {
+ type GeneratorState = Value
+ val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value
+ }
+ import GeneratorState._
+
private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")
require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value")
@@ -89,59 +108,100 @@ private[streaming] class BlockGenerator(
private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
@volatile private var currentBuffer = new ArrayBuffer[Any]
- @volatile private var stopped = false
+ @volatile private var state = Initialized
/** Start block generating and pushing threads. */
- def start() {
- blockIntervalTimer.start()
- blockPushingThread.start()
- logInfo("Started BlockGenerator")
+ def start(): Unit = synchronized {
+ if (state == Initialized) {
+ state = Active
+ blockIntervalTimer.start()
+ blockPushingThread.start()
+ logInfo("Started BlockGenerator")
+ } else {
+ throw new SparkException(
+ s"Cannot start BlockGenerator as its not in the Initialized state [state = $state]")
+ }
}
- /** Stop all threads. */
- def stop() {
+ /**
+ * Stop everything in the right order such that all the data added is pushed out correctly.
+ * - First, stop adding data to the current buffer.
+ * - Second, stop generating blocks.
+ * - Finally, wait for queue of to-be-pushed blocks to be drained.
+ */
+ def stop(): Unit = {
+ // Set the state to stop adding data
+ synchronized {
+ if (state == Active) {
+ state = StoppedAddingData
+ } else {
+ logWarning(s"Cannot stop BlockGenerator as its not in the Active state [state = $state]")
+ return
+ }
+ }
+
+ // Stop generating blocks and set the state for block pushing thread to start draining the queue
logInfo("Stopping BlockGenerator")
blockIntervalTimer.stop(interruptTimer = false)
- stopped = true
- logInfo("Waiting for block pushing thread")
+ synchronized { state = StoppedGeneratingBlocks }
+
+ // Wait for the queue to drain and mark generated as stopped
+ logInfo("Waiting for block pushing thread to terminate")
blockPushingThread.join()
+ synchronized { state = StoppedAll }
logInfo("Stopped BlockGenerator")
}
/**
- * Push a single data item into the buffer. All received data items
- * will be periodically pushed into BlockManager.
+ * Push a single data item into the buffer.
*/
- def addData (data: Any): Unit = synchronized {
- waitToPush()
- currentBuffer += data
+ def addData(data: Any): Unit = synchronized {
+ if (state == Active) {
+ waitToPush()
+ currentBuffer += data
+ } else {
+ throw new SparkException(
+ "Cannot add data as BlockGenerator has not been started or has been stopped")
+ }
}
/**
* Push a single data item into the buffer. After buffering the data, the
- * `BlockGeneratorListener.onAddData` callback will be called. All received data items
- * will be periodically pushed into BlockManager.
+ * `BlockGeneratorListener.onAddData` callback will be called.
*/
def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized {
- waitToPush()
- currentBuffer += data
- listener.onAddData(data, metadata)
+ if (state == Active) {
+ waitToPush()
+ currentBuffer += data
+ listener.onAddData(data, metadata)
+ } else {
+ throw new SparkException(
+ "Cannot add data as BlockGenerator has not been started or has been stopped")
+ }
}
/**
* Push multiple data items into the buffer. After buffering the data, the
- * `BlockGeneratorListener.onAddData` callback will be called. All received data items
- * will be periodically pushed into BlockManager. Note that all the data items is guaranteed
- * to be present in a single block.
+ * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items
+ * are atomically added to the buffer, and are hence guaranteed to be present in a single block.
*/
def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized {
- dataIterator.foreach { data =>
- waitToPush()
- currentBuffer += data
+ if (state == Active) {
+ dataIterator.foreach { data =>
+ waitToPush()
+ currentBuffer += data
+ }
+ listener.onAddData(dataIterator, metadata)
+ } else {
+ throw new SparkException(
+ "Cannot add data as BlockGenerator has not been started or has been stopped")
}
- listener.onAddData(dataIterator, metadata)
}
+ def isActive(): Boolean = state == Active
+
+ def isStopped(): Boolean = state == StoppedAll
+
/** Change the buffer to which single records are added to. */
private def updateCurrentBuffer(time: Long): Unit = synchronized {
try {
@@ -165,18 +225,21 @@ private[streaming] class BlockGenerator(
/** Keep pushing blocks to the BlockManager. */
private def keepPushingBlocks() {
logInfo("Started block pushing thread")
+
+ def isGeneratingBlocks = synchronized { state == Active || state == StoppedAddingData }
try {
- while (!stopped) {
- Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match {
+ while (isGeneratingBlocks) {
+ Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match {
case Some(block) => pushBlock(block)
case None =>
}
}
- // Push out the blocks that are still left
+
+ // At this point, state is StoppedGeneratingBlock. So drain the queue of to-be-pushed blocks.
logInfo("Pushing out the last " + blocksForPushing.size() + " blocks")
while (!blocksForPushing.isEmpty) {
- logDebug("Getting block ")
val block = blocksForPushing.take()
+ logDebug(s"Pushing block $block")
pushBlock(block)
logInfo("Blocks left to push " + blocksForPushing.size())
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
index f663def4c0..bca1fbc8fd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
@@ -45,8 +45,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
/**
* Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}.
*/
- def getCurrentLimit: Long =
- rateLimiter.getRate.toLong
+ def getCurrentLimit: Long = rateLimiter.getRate.toLong
/**
* Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
index 7504fa44d9..554aae0117 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
@@ -116,12 +116,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* being pushed into Spark's memory.
*/
def store(dataItem: T) {
- executor.pushSingle(dataItem)
+ supervisor.pushSingle(dataItem)
}
/** Store an ArrayBuffer of received data as a data block into Spark's memory. */
def store(dataBuffer: ArrayBuffer[T]) {
- executor.pushArrayBuffer(dataBuffer, None, None)
+ supervisor.pushArrayBuffer(dataBuffer, None, None)
}
/**
@@ -130,12 +130,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* for being used in the corresponding InputDStream.
*/
def store(dataBuffer: ArrayBuffer[T], metadata: Any) {
- executor.pushArrayBuffer(dataBuffer, Some(metadata), None)
+ supervisor.pushArrayBuffer(dataBuffer, Some(metadata), None)
}
/** Store an iterator of received data as a data block into Spark's memory. */
def store(dataIterator: Iterator[T]) {
- executor.pushIterator(dataIterator, None, None)
+ supervisor.pushIterator(dataIterator, None, None)
}
/**
@@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* for being used in the corresponding InputDStream.
*/
def store(dataIterator: java.util.Iterator[T], metadata: Any) {
- executor.pushIterator(dataIterator, Some(metadata), None)
+ supervisor.pushIterator(dataIterator, Some(metadata), None)
}
/** Store an iterator of received data as a data block into Spark's memory. */
def store(dataIterator: java.util.Iterator[T]) {
- executor.pushIterator(dataIterator, None, None)
+ supervisor.pushIterator(dataIterator, None, None)
}
/**
@@ -158,7 +158,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* for being used in the corresponding InputDStream.
*/
def store(dataIterator: Iterator[T], metadata: Any) {
- executor.pushIterator(dataIterator, Some(metadata), None)
+ supervisor.pushIterator(dataIterator, Some(metadata), None)
}
/**
@@ -167,7 +167,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* that Spark is configured to use.
*/
def store(bytes: ByteBuffer) {
- executor.pushBytes(bytes, None, None)
+ supervisor.pushBytes(bytes, None, None)
}
/**
@@ -176,12 +176,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* for being used in the corresponding InputDStream.
*/
def store(bytes: ByteBuffer, metadata: Any) {
- executor.pushBytes(bytes, Some(metadata), None)
+ supervisor.pushBytes(bytes, Some(metadata), None)
}
/** Report exceptions in receiving data. */
def reportError(message: String, throwable: Throwable) {
- executor.reportError(message, throwable)
+ supervisor.reportError(message, throwable)
}
/**
@@ -193,7 +193,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* The `message` will be reported to the driver.
*/
def restart(message: String) {
- executor.restartReceiver(message)
+ supervisor.restartReceiver(message)
}
/**
@@ -205,7 +205,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* The `message` and `exception` will be reported to the driver.
*/
def restart(message: String, error: Throwable) {
- executor.restartReceiver(message, Some(error))
+ supervisor.restartReceiver(message, Some(error))
}
/**
@@ -215,22 +215,22 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* in a background thread.
*/
def restart(message: String, error: Throwable, millisecond: Int) {
- executor.restartReceiver(message, Some(error), millisecond)
+ supervisor.restartReceiver(message, Some(error), millisecond)
}
/** Stop the receiver completely. */
def stop(message: String) {
- executor.stop(message, None)
+ supervisor.stop(message, None)
}
/** Stop the receiver completely due to an exception */
def stop(message: String, error: Throwable) {
- executor.stop(message, Some(error))
+ supervisor.stop(message, Some(error))
}
/** Check if the receiver has started or not. */
def isStarted(): Boolean = {
- executor.isReceiverStarted()
+ supervisor.isReceiverStarted()
}
/**
@@ -238,7 +238,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* the receiving of data should be stopped.
*/
def isStopped(): Boolean = {
- executor.isReceiverStopped()
+ supervisor.isReceiverStopped()
}
/**
@@ -257,7 +257,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
private var id: Int = -1
/** Handler object that runs the receiver. This is instantiated lazily in the worker. */
- private[streaming] var executor_ : ReceiverSupervisor = null
+ @transient private var _supervisor : ReceiverSupervisor = null
/** Set the ID of the DStream that this receiver is associated with. */
private[streaming] def setReceiverId(id_ : Int) {
@@ -265,15 +265,17 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
}
/** Attach Network Receiver executor to this receiver. */
- private[streaming] def attachExecutor(exec: ReceiverSupervisor) {
- assert(executor_ == null)
- executor_ = exec
+ private[streaming] def attachSupervisor(exec: ReceiverSupervisor) {
+ assert(_supervisor == null)
+ _supervisor = exec
}
- /** Get the attached executor. */
- private def executor: ReceiverSupervisor = {
- assert(executor_ != null, "Executor has not been attached to this receiver")
- executor_
+ /** Get the attached supervisor. */
+ private[streaming] def supervisor: ReceiverSupervisor = {
+ assert(_supervisor != null,
+ "A ReceiverSupervisor have not been attached to the receiver yet. Maybe you are starting " +
+ "some computation in the receiver before the Receiver.onStart() has been called.")
+ _supervisor
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index e98017a637..158d1ba2f1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -44,8 +44,8 @@ private[streaming] abstract class ReceiverSupervisor(
}
import ReceiverState._
- // Attach the executor to the receiver
- receiver.attachExecutor(this)
+ // Attach the supervisor to the receiver
+ receiver.attachSupervisor(this)
private val futureExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128))
@@ -60,7 +60,7 @@ private[streaming] abstract class ReceiverSupervisor(
private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000)
/** The current maximum rate limit for this receiver. */
- private[streaming] def getCurrentRateLimit: Option[Long] = None
+ private[streaming] def getCurrentRateLimit: Long = Long.MaxValue
/** Exception associated with the stopping of the receiver */
@volatile protected var stoppingError: Throwable = null
@@ -92,13 +92,30 @@ private[streaming] abstract class ReceiverSupervisor(
optionalBlockId: Option[StreamBlockId]
)
+ /**
+ * Create a custom [[BlockGenerator]] that the receiver implementation can directly control
+ * using their provided [[BlockGeneratorListener]].
+ *
+ * Note: Do not explicitly start or stop the `BlockGenerator`, the `ReceiverSupervisorImpl`
+ * will take care of it.
+ */
+ def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator
+
/** Report errors. */
def reportError(message: String, throwable: Throwable)
- /** Called when supervisor is started */
+ /**
+ * Called when supervisor is started.
+ * Note that this must be called before the receiver.onStart() is called to ensure
+ * things like [[BlockGenerator]]s are started before the receiver starts sending data.
+ */
protected def onStart() { }
- /** Called when supervisor is stopped */
+ /**
+ * Called when supervisor is stopped.
+ * Note that this must be called after the receiver.onStop() is called to ensure
+ * things like [[BlockGenerator]]s are cleaned up after the receiver stops sending data.
+ */
protected def onStop(message: String, error: Option[Throwable]) { }
/** Called when receiver is started. Return true if the driver accepts us */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 0d802f8354..59ef58d232 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.receiver
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import com.google.common.base.Throwables
@@ -81,15 +82,20 @@ private[streaming] class ReceiverSupervisorImpl(
cleanupOldBlocks(threshTime)
case UpdateRateLimit(eps) =>
logInfo(s"Received a new rate limit: $eps.")
- blockGenerator.updateRate(eps)
+ registeredBlockGenerators.foreach { bg =>
+ bg.updateRate(eps)
+ }
}
})
/** Unique block ids if one wants to add blocks directly */
private val newBlockId = new AtomicLong(System.currentTimeMillis())
+ private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator]
+ with mutable.SynchronizedBuffer[BlockGenerator]
+
/** Divides received data records into data blocks for pushing in BlockManager. */
- private val blockGenerator = new BlockGenerator(new BlockGeneratorListener {
+ private val defaultBlockGeneratorListener = new BlockGeneratorListener {
def onAddData(data: Any, metadata: Any): Unit = { }
def onGenerateBlock(blockId: StreamBlockId): Unit = { }
@@ -101,14 +107,15 @@ private[streaming] class ReceiverSupervisorImpl(
def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) {
pushArrayBuffer(arrayBuffer, None, Some(blockId))
}
- }, streamId, env.conf)
+ }
+ private val defaultBlockGenerator = createBlockGenerator(defaultBlockGeneratorListener)
- override private[streaming] def getCurrentRateLimit: Option[Long] =
- Some(blockGenerator.getCurrentLimit)
+ /** Get the current rate limit of the default block generator */
+ override private[streaming] def getCurrentRateLimit: Long = defaultBlockGenerator.getCurrentLimit
/** Push a single record of received data into block generator. */
def pushSingle(data: Any) {
- blockGenerator.addData(data)
+ defaultBlockGenerator.addData(data)
}
/** Store an ArrayBuffer of received data as a data block into Spark's memory. */
@@ -162,11 +169,11 @@ private[streaming] class ReceiverSupervisorImpl(
}
override protected def onStart() {
- blockGenerator.start()
+ registeredBlockGenerators.foreach { _.start() }
}
override protected def onStop(message: String, error: Option[Throwable]) {
- blockGenerator.stop()
+ registeredBlockGenerators.foreach { _.stop() }
env.rpcEnv.stop(endpoint)
}
@@ -183,6 +190,16 @@ private[streaming] class ReceiverSupervisorImpl(
logInfo("Stopped receiver " + streamId)
}
+ override def createBlockGenerator(
+ blockGeneratorListener: BlockGeneratorListener): BlockGenerator = {
+ // Cleanup BlockGenerators that have already been stopped
+ registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() }
+
+ val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf)
+ registeredBlockGenerators += newBlockGenerator
+ newBlockGenerator
+ }
+
/** Generate new block ID */
private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 67c2d90094..1bba7a143e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming
import java.io.File
-import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.reflect.ClassTag
import com.google.common.base.Charsets
@@ -33,7 +33,7 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
-import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver}
+import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver}
import org.apache.spark.util.{Clock, ManualClock, Utils}
/**
@@ -397,26 +397,24 @@ class CheckpointSuite extends TestSuiteBase {
ssc = new StreamingContext(conf, batchDuration)
ssc.checkpoint(checkpointDir)
- val dstream = new RateLimitInputDStream(ssc) {
+ val dstream = new RateTestInputDStream(ssc) {
override val rateController =
- Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+ Some(new ReceiverRateController(id, new ConstantEstimator(200)))
}
- SingletonTestRateReceiver.reset()
val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2))
output.register()
runStreams(ssc, 5, 5)
- SingletonTestRateReceiver.reset()
ssc = new StreamingContext(checkpointDir)
ssc.start()
val outputNew = advanceTimeWithRealDelay(ssc, 2)
- eventually(timeout(5.seconds)) {
- assert(dstream.getCurrentRateLimit === Some(200))
+ eventually(timeout(10.seconds)) {
+ assert(RateTestReceiver.getActive().nonEmpty)
+ assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200)
}
ssc.stop()
- ssc = null
}
// This tests whether file input stream remembers what files were seen before
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index 13b4d17c86..01279b34f7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -129,32 +129,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
}
}
- test("block generator") {
- val blockGeneratorListener = new FakeBlockGeneratorListener
- val blockIntervalMs = 200
- val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms")
- val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
- val expectedBlocks = 5
- val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2)
- val generatedData = new ArrayBuffer[Int]
-
- // Generate blocks
- val startTime = System.currentTimeMillis()
- blockGenerator.start()
- var count = 0
- while(System.currentTimeMillis - startTime < waitTime) {
- blockGenerator.addData(count)
- generatedData += count
- count += 1
- Thread.sleep(10)
- }
- blockGenerator.stop()
-
- val recordedData = blockGeneratorListener.arrayBuffers.flatten
- assert(blockGeneratorListener.arrayBuffers.size > 0)
- assert(recordedData.toSet === generatedData.toSet)
- }
-
ignore("block generator throttling") {
val blockGeneratorListener = new FakeBlockGeneratorListener
val blockIntervalMs = 100
@@ -348,6 +322,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
}
override protected def onReceiverStart(): Boolean = true
+
+ override def createBlockGenerator(
+ blockGeneratorListener: BlockGeneratorListener): BlockGenerator = {
+ null
+ }
}
/**
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
new file mode 100644
index 0000000000..a38cc603f2
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.receiver
+
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.Matchers._
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.util.ManualClock
+import org.apache.spark.{SparkException, SparkConf, SparkFunSuite}
+
+class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter {
+
+ private val blockIntervalMs = 10
+ private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms")
+ @volatile private var blockGenerator: BlockGenerator = null
+
+ after {
+ if (blockGenerator != null) {
+ blockGenerator.stop()
+ }
+ }
+
+ test("block generation and data callbacks") {
+ val listener = new TestBlockGeneratorListener
+ val clock = new ManualClock()
+
+ require(blockIntervalMs > 5)
+ require(listener.onAddDataCalled === false)
+ require(listener.onGenerateBlockCalled === false)
+ require(listener.onPushBlockCalled === false)
+
+ // Verify that creating the generator does not start it
+ blockGenerator = new BlockGenerator(listener, 0, conf, clock)
+ assert(blockGenerator.isActive() === false, "block generator active before start()")
+ assert(blockGenerator.isStopped() === false, "block generator stopped before start()")
+ assert(listener.onAddDataCalled === false)
+ assert(listener.onGenerateBlockCalled === false)
+ assert(listener.onPushBlockCalled === false)
+
+ // Verify start marks the generator active, but does not call the callbacks
+ blockGenerator.start()
+ assert(blockGenerator.isActive() === true, "block generator active after start()")
+ assert(blockGenerator.isStopped() === false, "block generator stopped after start()")
+ withClue("callbacks called before adding data") {
+ assert(listener.onAddDataCalled === false)
+ assert(listener.onGenerateBlockCalled === false)
+ assert(listener.onPushBlockCalled === false)
+ }
+
+ // Verify whether addData() adds data that is present in generated blocks
+ val data1 = 1 to 10
+ data1.foreach { blockGenerator.addData _ }
+ withClue("callbacks called on adding data without metadata and without block generation") {
+ assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback()
+ assert(listener.onGenerateBlockCalled === false)
+ assert(listener.onPushBlockCalled === false)
+ }
+ clock.advance(blockIntervalMs) // advance clock to generate blocks
+ withClue("blocks not generated or pushed") {
+ eventually(timeout(1 second)) {
+ assert(listener.onGenerateBlockCalled === true)
+ assert(listener.onPushBlockCalled === true)
+ }
+ }
+ listener.pushedData should contain theSameElementsInOrderAs (data1)
+ assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback()
+
+ // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly
+ val data2 = 11 to 20
+ val metadata2 = data2.map { _.toString }
+ data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) }
+ assert(listener.onAddDataCalled === true)
+ listener.addedData should contain theSameElementsInOrderAs (data2)
+ listener.addedMetadata should contain theSameElementsInOrderAs (metadata2)
+ clock.advance(blockIntervalMs) // advance clock to generate blocks
+ eventually(timeout(1 second)) {
+ listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2)
+ }
+
+ // Verify addMultipleDataWithCallback() add data+metadata and and callbacks are called correctly
+ val data3 = 21 to 30
+ val metadata3 = "metadata"
+ blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3)
+ listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3)
+ clock.advance(blockIntervalMs) // advance clock to generate blocks
+ eventually(timeout(1 second)) {
+ listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3)
+ }
+
+ // Stop the block generator by starting the stop on a different thread and
+ // then advancing the manual clock for the stopping to proceed.
+ val thread = stopBlockGenerator(blockGenerator)
+ eventually(timeout(1 second), interval(10 milliseconds)) {
+ clock.advance(blockIntervalMs)
+ assert(blockGenerator.isStopped() === true)
+ }
+ thread.join()
+
+ // Verify that the generator cannot be used any more
+ intercept[SparkException] {
+ blockGenerator.addData(1)
+ }
+ intercept[SparkException] {
+ blockGenerator.addDataWithCallback(1, 1)
+ }
+ intercept[SparkException] {
+ blockGenerator.addMultipleDataWithCallback(Iterator(1), 1)
+ }
+ intercept[SparkException] {
+ blockGenerator.start()
+ }
+ blockGenerator.stop() // Calling stop again should be fine
+ }
+
+ test("stop ensures correct shutdown") {
+ val listener = new TestBlockGeneratorListener
+ val clock = new ManualClock()
+ blockGenerator = new BlockGenerator(listener, 0, conf, clock)
+ require(listener.onGenerateBlockCalled === false)
+ blockGenerator.start()
+ assert(blockGenerator.isActive() === true, "block generator")
+ assert(blockGenerator.isStopped() === false)
+
+ val data = 1 to 1000
+ data.foreach { blockGenerator.addData _ }
+
+ // Verify that stop() shutdowns everything in the right order
+ // - First, stop receiving new data
+ // - Second, wait for final block with all buffered data to be generated
+ // - Finally, wait for all blocks to be pushed
+ clock.advance(1) // to make sure that the timer for another interval to complete
+ val thread = stopBlockGenerator(blockGenerator)
+ eventually(timeout(1 second), interval(10 milliseconds)) {
+ assert(blockGenerator.isActive() === false)
+ }
+ assert(blockGenerator.isStopped() === false)
+
+ // Verify that data cannot be added
+ intercept[SparkException] {
+ blockGenerator.addData(1)
+ }
+ intercept[SparkException] {
+ blockGenerator.addDataWithCallback(1, null)
+ }
+ intercept[SparkException] {
+ blockGenerator.addMultipleDataWithCallback(Iterator(1), null)
+ }
+
+ // Verify that stop() stays blocked until another block containing all the data is generated
+ // This intercept always succeeds, as the body either will either throw a timeout exception
+ // (expected as stop() should never complete) or a SparkException (unexpected as stop()
+ // completed and thread terminated).
+ val exception = intercept[Exception] {
+ failAfter(200 milliseconds) {
+ thread.join()
+ throw new SparkException(
+ "BlockGenerator.stop() completed before generating timer was stopped")
+ }
+ }
+ exception should not be a [SparkException]
+
+
+ // Verify that the final data is present in the final generated block and
+ // pushed before complete stop
+ assert(blockGenerator.isStopped() === false) // generator has not stopped yet
+ clock.advance(blockIntervalMs) // force block generation
+ failAfter(1 second) {
+ thread.join()
+ }
+ assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped
+ assert(listener.pushedData === data, "All data not pushed by stop()")
+ }
+
+ test("block push errors are reported") {
+ val listener = new TestBlockGeneratorListener {
+ @volatile var errorReported = false
+ override def onPushBlock(
+ blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+ throw new SparkException("test")
+ }
+ override def onError(message: String, throwable: Throwable): Unit = {
+ errorReported = true
+ }
+ }
+ blockGenerator = new BlockGenerator(listener, 0, conf)
+ blockGenerator.start()
+ assert(listener.errorReported === false)
+ blockGenerator.addData(1)
+ eventually(timeout(1 second), interval(10 milliseconds)) {
+ assert(listener.errorReported === true)
+ }
+ blockGenerator.stop()
+ }
+
+ /**
+ * Helper method to stop the block generator with manual clock in a different thread,
+ * so that the main thread can advance the clock that allows the stopping to proceed.
+ */
+ private def stopBlockGenerator(blockGenerator: BlockGenerator): Thread = {
+ val thread = new Thread() {
+ override def run(): Unit = {
+ blockGenerator.stop()
+ }
+ }
+ thread.start()
+ thread
+ }
+
+ /** A listener for BlockGenerator that records the data in the callbacks */
+ private class TestBlockGeneratorListener extends BlockGeneratorListener {
+ val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any]
+ val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any]
+ val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any]
+ @volatile var onGenerateBlockCalled = false
+ @volatile var onAddDataCalled = false
+ @volatile var onPushBlockCalled = false
+
+ override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+ pushedData ++= arrayBuffer
+ onPushBlockCalled = true
+ }
+ override def onError(message: String, throwable: Throwable): Unit = {}
+ override def onGenerateBlock(blockId: StreamBlockId): Unit = {
+ onGenerateBlockCalled = true
+ }
+ override def onAddData(data: Any, metadata: Any): Unit = {
+ addedData += data
+ addedMetadata += metadata
+ onAddDataCalled = true
+ }
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
index 921da773f6..1eb52b7029 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
@@ -18,10 +18,7 @@
package org.apache.spark.streaming.scheduler
import scala.collection.mutable
-import scala.reflect.ClassTag
-import scala.util.control.NonFatal
-import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
@@ -32,72 +29,63 @@ class RateControllerSuite extends TestSuiteBase {
override def useManualClock: Boolean = false
- test("rate controller publishes updates") {
+ override def batchDuration: Duration = Milliseconds(50)
+
+ test("RateController - rate controller publishes updates after batches complete") {
val ssc = new StreamingContext(conf, batchDuration)
withStreamingContext(ssc) { ssc =>
- val dstream = new RateLimitInputDStream(ssc)
+ val dstream = new RateTestInputDStream(ssc)
dstream.register()
ssc.start()
eventually(timeout(10.seconds)) {
- assert(dstream.publishCalls > 0)
+ assert(dstream.publishedRates > 0)
}
}
}
- test("publish rates reach receivers") {
+ test("ReceiverRateController - published rates reach receivers") {
val ssc = new StreamingContext(conf, batchDuration)
withStreamingContext(ssc) { ssc =>
- val dstream = new RateLimitInputDStream(ssc) {
+ val estimator = new ConstantEstimator(100)
+ val dstream = new RateTestInputDStream(ssc) {
override val rateController =
- Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+ Some(new ReceiverRateController(id, estimator))
}
dstream.register()
- SingletonTestRateReceiver.reset()
ssc.start()
- eventually(timeout(10.seconds)) {
- assert(dstream.getCurrentRateLimit === Some(200))
+ // Wait for receiver to start
+ eventually(timeout(5.seconds)) {
+ RateTestReceiver.getActive().nonEmpty
}
- }
- }
- test("multiple publish rates reach receivers") {
- val ssc = new StreamingContext(conf, batchDuration)
- withStreamingContext(ssc) { ssc =>
- val rates = Seq(100L, 200L, 300L)
-
- val dstream = new RateLimitInputDStream(ssc) {
- override val rateController =
- Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*)))
+ // Update rate in the estimator and verify whether the rate was published to the receiver
+ def updateRateAndVerify(rate: Long): Unit = {
+ estimator.updateRate(rate)
+ eventually(timeout(5.seconds)) {
+ assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === rate)
+ }
}
- SingletonTestRateReceiver.reset()
- dstream.register()
-
- val observedRates = mutable.HashSet.empty[Long]
- ssc.start()
- eventually(timeout(20.seconds)) {
- dstream.getCurrentRateLimit.foreach(observedRates += _)
- // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver
- observedRates should contain theSameElementsAs (rates :+ Long.MaxValue)
+ // Verify multiple rate update
+ Seq(100, 200, 300).foreach { rate =>
+ updateRateAndVerify(rate)
}
}
}
}
-private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator {
- private var idx: Int = 0
+private[streaming] class ConstantEstimator(@volatile private var rate: Long)
+ extends RateEstimator {
- private def nextRate(): Double = {
- val rate = rates(idx)
- idx = (idx + 1) % rates.size
- rate
+ def updateRate(newRate: Long): Unit = {
+ rate = newRate
}
def compute(
time: Long,
elements: Long,
processingDelay: Long,
- schedulingDelay: Long): Option[Double] = Some(nextRate())
+ schedulingDelay: Long): Option[Double] = Some(rate)
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index afad5f16db..dd292ba4dd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -17,48 +17,43 @@
package org.apache.spark.streaming.scheduler
+import scala.collection.mutable.ArrayBuffer
+
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.SparkConf
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
-import org.apache.spark.streaming.receiver._
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.receiver._
/** Testsuite for receiver scheduling */
class ReceiverTrackerSuite extends TestSuiteBase {
- val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test")
-
- test("Receiver tracker - propagates rate limit") {
- withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc =>
- object ReceiverStartedWaiter extends StreamingListener {
- @volatile
- var started = false
-
- override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
- started = true
- }
- }
- ssc.addStreamingListener(ReceiverStartedWaiter)
+ test("send rate update to receivers") {
+ withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
ssc.scheduler.listenerBus.start(ssc.sc)
- SingletonTestRateReceiver.reset()
val newRateLimit = 100L
- val inputDStream = new RateLimitInputDStream(ssc)
+ val inputDStream = new RateTestInputDStream(ssc)
val tracker = new ReceiverTracker(ssc)
tracker.start()
try {
// we wait until the Receiver has registered with the tracker,
// otherwise our rate update is lost
eventually(timeout(5 seconds)) {
- assert(ReceiverStartedWaiter.started)
+ assert(RateTestReceiver.getActive().nonEmpty)
}
+
+
+ // Verify that the rate of the block generator in the receiver get updated
+ val activeReceiver = RateTestReceiver.getActive().get
tracker.sendRateUpdate(inputDStream.id, newRateLimit)
- // this is an async message, we need to wait a bit for it to be processed
- eventually(timeout(3 seconds)) {
- assert(inputDStream.getCurrentRateLimit.get === newRateLimit)
+ eventually(timeout(5 seconds)) {
+ assert(activeReceiver.getDefaultBlockGeneratorRateLimit() === newRateLimit,
+ "default block generator did not receive rate update")
+ assert(activeReceiver.getCustomBlockGeneratorRateLimit() === newRateLimit,
+ "other block generator did not receive rate update")
}
} finally {
tracker.stop(false)
@@ -67,69 +62,73 @@ class ReceiverTrackerSuite extends TestSuiteBase {
}
}
-/**
- * An input DStream with a hard-coded receiver that gives access to internals for testing.
- *
- * @note Make sure to call {{{SingletonDummyReceiver.reset()}}} before using this in a test,
- * or otherwise you may get {{{NotSerializableException}}} when trying to serialize
- * the receiver.
- * @see [[[SingletonDummyReceiver]]].
- */
-private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext)
+/** An input DStream with for testing rate controlling */
+private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext)
extends ReceiverInputDStream[Int](ssc_) {
- override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver
-
- def getCurrentRateLimit: Option[Long] = {
- invokeExecutorMethod.getCurrentRateLimit
- }
+ override def getReceiver(): Receiver[Int] = new RateTestReceiver(id)
@volatile
- var publishCalls = 0
+ var publishedRates = 0
override val rateController: Option[RateController] = {
- Some(new RateController(id, new ConstantEstimator(100.0)) {
+ Some(new RateController(id, new ConstantEstimator(100)) {
override def publish(rate: Long): Unit = {
- publishCalls += 1
+ publishedRates += 1
}
})
}
+}
- private def invokeExecutorMethod: ReceiverSupervisor = {
- val c = classOf[Receiver[_]]
- val ex = c.getDeclaredMethod("executor")
- ex.setAccessible(true)
- ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor]
+/** A receiver implementation for testing rate controlling */
+private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
+ extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+
+ private lazy val customBlockGenerator = supervisor.createBlockGenerator(
+ new BlockGeneratorListener {
+ override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {}
+ override def onError(message: String, throwable: Throwable): Unit = {}
+ override def onGenerateBlock(blockId: StreamBlockId): Unit = {}
+ override def onAddData(data: Any, metadata: Any): Unit = {}
+ }
+ )
+
+ setReceiverId(receiverId)
+
+ override def onStart(): Unit = {
+ customBlockGenerator
+ RateTestReceiver.registerReceiver(this)
}
-}
-/**
- * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when
- * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being
- * serialized when receivers are installed on executors.
- *
- * @note It's necessary to be a top-level object, or else serialization would create another
- * one on the executor side and we won't be able to read its rate limit.
- */
-private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) {
+ override def onStop(): Unit = {
+ RateTestReceiver.deregisterReceiver()
+ }
+
+ override def preferredLocation: Option[String] = host
- /** Reset the object to be usable in another test. */
- def reset(): Unit = {
- executor_ = null
+ def getDefaultBlockGeneratorRateLimit(): Long = {
+ supervisor.getCurrentRateLimit
+ }
+
+ def getCustomBlockGeneratorRateLimit(): Long = {
+ customBlockGenerator.getCurrentLimit
}
}
/**
- * Dummy receiver implementation
+ * A helper object to RateTestReceiver that give access to the currently active RateTestReceiver
+ * instance.
*/
-private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
- extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+private[streaming] object RateTestReceiver {
+ @volatile private var activeReceiver: RateTestReceiver = null
- setReceiverId(receiverId)
-
- override def onStart(): Unit = {}
+ def registerReceiver(receiver: RateTestReceiver): Unit = {
+ activeReceiver = receiver
+ }
- override def onStop(): Unit = {}
+ def deregisterReceiver(): Unit = {
+ activeReceiver = null
+ }
- override def preferredLocation: Option[String] = host
+ def getActive(): Option[RateTestReceiver] = Option(activeReceiver)
}