diff options
3 files changed, 100 insertions, 9 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index dd32ad5ad8..0148cb51c6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -72,8 +72,10 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: /** * Stop the timer, and return the last time the callback was made. - * interruptTimer = true will interrupt the callback + * - interruptTimer = true will interrupt the callback * if it is in progress (not guaranteed to give correct time in this case). + * - interruptTimer = false guarantees that there will be at least one callback after `stop` has + * been called. */ def stop(interruptTimer: Boolean): Long = synchronized { if (!stopped) { @@ -87,18 +89,23 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: prevTime } + private def triggerActionForNextInterval(): Unit = { + clock.waitTillTime(nextTime) + callback(nextTime) + prevTime = nextTime + nextTime += period + logDebug("Callback for " + name + " called at time " + prevTime) + } + /** * Repeatedly call the callback every interval. */ private def loop() { try { while (!stopped) { - clock.waitTillTime(nextTime) - callback(nextTime) - prevTime = nextTime - nextTime += period - logDebug("Callback for " + name + " called at time " + prevTime) + triggerActionForNextInterval() } + triggerActionForNextInterval() } catch { case e: InterruptedException => } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index a38cc603f2..2f11b255f1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -184,9 +184,10 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { // Verify that the final data is present in the final generated block and // pushed before complete stop assert(blockGenerator.isStopped() === false) // generator has not stopped yet - clock.advance(blockIntervalMs) // force block generation - failAfter(1 second) { - thread.join() + eventually(timeout(10 seconds), interval(10 milliseconds)) { + // Keep calling `advance` to avoid blocking forever in `clock.waitTillTime` + clock.advance(blockIntervalMs) + assert(thread.isAlive === false) } assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped assert(listener.pushedData === data, "All data not pushed by stop()") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala new file mode 100644 index 0000000000..0544972d95 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import scala.collection.mutable +import scala.concurrent.duration._ + +import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ManualClock + +class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { + + test("basic") { + val clock = new ManualClock() + val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val timer = new RecurringTimer(clock, 100, time => { + results += time + }, "RecurringTimerSuite-basic") + timer.start(0) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L)) + } + clock.advance(100) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L, 100L)) + } + clock.advance(200) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L, 100L, 200L, 300L)) + } + assert(timer.stop(interruptTimer = true) === 300L) + } + + test("SPARK-10224: call 'callback' after stopping") { + val clock = new ManualClock() + val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val timer = new RecurringTimer(clock, 100, time => { + results += time + }, "RecurringTimerSuite-SPARK-10224") + timer.start(0) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L)) + } + @volatile var lastTime = -1L + // Now RecurringTimer is waiting for the next interval + val thread = new Thread { + override def run(): Unit = { + lastTime = timer.stop(interruptTimer = false) + } + } + thread.start() + val stopped = PrivateMethod[RecurringTimer]('stopped) + // Make sure the `stopped` field has been changed + eventually(timeout(10.seconds), interval(10.millis)) { + assert(timer.invokePrivate(stopped()) === true) + } + clock.advance(200) + // When RecurringTimer is awake from clock.waitTillTime, it will call `callback` once. + // Then it will find `stopped` is true and exit the loop, but it should call `callback` again + // before exiting its internal thread. + thread.join() + assert(results === Seq(0L, 100L, 200L)) + assert(lastTime === 200L) + } +} |