about | summary | refs | log | tree | commit | diff
path: root/core
diff options
context:
space:
mode:
author Andrew Or <andrew@databricks.com> 2015-07-02 13:59:56 -0700
committer Andrew Or <andrew@databricks.com> 2015-07-02 13:59:56 -0700
commit cd2035507891a7f426f6f45902d3b5f4fdbe88cf (patch)
tree a03f61825371db98531f61c63ec9fc67dd2f42ae /core
parent fcbcba66c92871fe3936e5ca605017e9c2a2eb95 (diff)
download spark-cd2035507891a7f426f6f45902d3b5f4fdbe88cf.tar.gz
spark-cd2035507891a7f426f6f45902d3b5f4fdbe88cf.tar.bz2
spark-cd2035507891a7f426f6f45902d3b5f4fdbe88cf.zip
[SPARK-7835] Refactor HeartbeatReceiverSuite for coverage + cleanup
The existing test suite has a lot of duplicate code and doesn't even cover the most fundamental feature of the HeartbeatReceiver, which is expiring hosts that have not responded in a while. This introduces manual clocks in `HeartbeatReceiver` and makes it respond to heartbeats only for registered executors. A few internal messages are moved to `receiveAndReply` to increase determinism of the tests so we don't have to rely on flaky constructs like `eventually`.

Author: Andrew Or <andrew@databricks.com>

Closes #7173 from andrewor14/heartbeat-receiver-tests and squashes the following commits:

4a903d6 [Andrew Or] Increase HeartReceiverSuite coverage and clean up
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala 89
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkContext.scala 2
-rw-r--r-- core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala 161
3 files changed, 191 insertions, 61 deletions
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 6909015ff6..221b1dab43 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -24,8 +24,8 @@ import scala.collection.mutable
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.scheduler._
+import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet
private[spark] case object ExpireDeadHosts
+private case class ExecutorRegistered(executorId: String)
+
+private case class ExecutorRemoved(executorId: String)
+
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
/**
* Lives in the driver to receive heartbeats from executors..
*/
-private[spark] class HeartbeatReceiver(sc: SparkContext)
- extends ThreadSafeRpcEndpoint with Logging {
+private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
+ extends ThreadSafeRpcEndpoint with SparkListener with Logging {
+
+ def this(sc: SparkContext) {
+ this(sc, new SystemClock)
+ }
+
+ sc.addSparkListener(this)
override val rpcEnv: RpcEnv = sc.env.rpcEnv
@@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
override def onStart(): Unit = {
timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- Option(self).foreach(_.send(ExpireDeadHosts))
+ Option(self).foreach(_.ask[Boolean](ExpireDeadHosts))
}
}, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
}
- override def receive: PartialFunction[Any, Unit] = {
- case ExpireDeadHosts =>
- expireDeadHosts()
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+
+ // Messages sent and received locally
+ case ExecutorRegistered(executorId) =>
+ executorLastSeen(executorId) = clock.getTimeMillis()
+ context.reply(true)
+ case ExecutorRemoved(executorId) =>
+ executorLastSeen.remove(executorId)
+ context.reply(true)
case TaskSchedulerIsSet =>
scheduler = sc.taskScheduler
- }
+ context.reply(true)
+ case ExpireDeadHosts =>
+ expireDeadHosts()
+ context.reply(true)
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ // Messages received from executors
case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
if (scheduler != null) {
- executorLastSeen(executorId) = System.currentTimeMillis()
- eventLoopThread.submit(new Runnable {
- override def run(): Unit = Utils.tryLogNonFatalError {
- val unknownExecutor = !scheduler.executorHeartbeatReceived(
- executorId, taskMetrics, blockManagerId)
- val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
- context.reply(response)
- }
- })
+ if (executorLastSeen.contains(executorId)) {
+ executorLastSeen(executorId) = clock.getTimeMillis()
+ eventLoopThread.submit(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ val unknownExecutor = !scheduler.executorHeartbeatReceived(
+ executorId, taskMetrics, blockManagerId)
+ val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
+ context.reply(response)
+ }
+ })
+ } else {
+ // This may happen if we get an executor's in-flight heartbeat immediately
+ // after we just removed it. It's not really an error condition so we should
+ // not log warning here. Otherwise there may be a lot of noise especially if
+ // we explicitly remove executors (SPARK-4134).
+ logDebug(s"Received heartbeat from unknown executor $executorId")
+ context.reply(HeartbeatResponse(reregisterBlockManager = true))
+ }
} else {
// Because Executor will sleep several seconds before sending the first "Heartbeat", this
// case rarely happens. However, if it really happens, log it and ask the executor to
@@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
}
}
+ /**
+ * If the heartbeat receiver is not stopped, notify it of executor registrations.
+ */
+ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
+ Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId)))
+ }
+
+ /**
+ * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't
+ * log superfluous errors.
+ *
+ * Note that we must do this after the executor is actually removed to guard against the
+ * following race condition: if we remove an executor's metadata from our data structure
+ * prematurely, we may get an in-flight heartbeat from the executor before the executor is
+ * actually removed, in which case we will still mark the executor as a dead host later
+ * and expire it with loud error messages.
+ */
+ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
+ Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId)))
+ }
+
private def expireDeadHosts(): Unit = {
logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.")
- val now = System.currentTimeMillis()
+ val now = clock.getTimeMillis()
for ((executorId, lastSeenMs) <- executorLastSeen) {
if (now - lastSeenMs > executorTimeoutMs) {
logWarning(s"Removing executor $executorId with no recent heartbeats: " +
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8eed46759f..d2547eeff2 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -498,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_schedulerBackend = sched
_taskScheduler = ts
_dagScheduler = new DAGScheduler(this)
- _heartbeatReceiver.send(TaskSchedulerIsSet)
+ _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet)
// start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
// constructor
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 911b3bddd1..b31b091966 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -17,64 +17,145 @@
package org.apache.spark
-import scala.concurrent.duration._
import scala.language.postfixOps
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.storage.BlockManagerId
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.mockito.Mockito.{mock, spy, verify, when}
import org.mockito.Matchers
import org.mockito.Matchers._
-import org.apache.spark.scheduler.TaskScheduler
-import org.apache.spark.util.RpcUtils
-import org.scalatest.concurrent.Eventually._
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.ManualClock
-class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext {
+class HeartbeatReceiverSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with PrivateMethodTester
+ with LocalSparkContext {
- test("HeartbeatReceiver") {
+ private val executorId1 = "executor-1"
+ private val executorId2 = "executor-2"
+
+ // Shared state that must be reset before and after each test
+ private var scheduler: TaskScheduler = null
+ private var heartbeatReceiver: HeartbeatReceiver = null
+ private var heartbeatReceiverRef: RpcEndpointRef = null
+ private var heartbeatReceiverClock: ManualClock = null
+
+ override def beforeEach(): Unit = {
sc = spy(new SparkContext("local[2]", "test"))
- val scheduler = mock(classOf[TaskScheduler])
- when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+ scheduler = mock(classOf[TaskScheduler])
when(sc.taskScheduler).thenReturn(scheduler)
+ heartbeatReceiverClock = new ManualClock
+ heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock)
+ heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver)
+ when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+ }
- val heartbeatReceiver = new HeartbeatReceiver(sc)
- sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(heartbeatReceiver.scheduler != null)
- }
- val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+ override def afterEach(): Unit = {
+ resetSparkContext()
+ scheduler = null
+ heartbeatReceiver = null
+ heartbeatReceiverRef = null
+ heartbeatReceiverClock = null
+ }
- val metrics = new TaskMetrics
- val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
- val response = receiverRef.askWithRetry[HeartbeatResponse](
- Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+ test("task scheduler is set correctly") {
+ assert(heartbeatReceiver.scheduler === null)
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ assert(heartbeatReceiver.scheduler !== null)
+ }
- verify(scheduler).executorHeartbeatReceived(
- Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
- assert(false === response.reregisterBlockManager)
+ test("normal heartbeat") {
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ triggerHeartbeat(executorId2, executorShouldReregister = false)
+ val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ assert(trackedExecutors.size === 2)
+ assert(trackedExecutors.contains(executorId1))
+ assert(trackedExecutors.contains(executorId2))
}
- test("HeartbeatReceiver re-register") {
- sc = spy(new SparkContext("local[2]", "test"))
- val scheduler = mock(classOf[TaskScheduler])
- when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
- when(sc.taskScheduler).thenReturn(scheduler)
+ test("reregister if scheduler is not ready yet") {
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ // Task scheduler not set in HeartbeatReceiver
+ triggerHeartbeat(executorId1, executorShouldReregister = true)
+ }
- val heartbeatReceiver = new HeartbeatReceiver(sc)
- sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(heartbeatReceiver.scheduler != null)
- }
- val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+ test("reregister if heartbeat from unregistered executor") {
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ // Received heartbeat from unknown receiver, so we ask it to re-register
+ triggerHeartbeat(executorId1, executorShouldReregister = true)
+ assert(executorLastSeen(heartbeatReceiver).isEmpty)
+ }
+
+ test("reregister if heartbeat from removed executor") {
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+ // Remove the second executor but not the first
+ heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy"))
+ // Now trigger the heartbeats
+ // A heartbeat from the second executor should require reregistering
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ triggerHeartbeat(executorId2, executorShouldReregister = true)
+ val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ assert(trackedExecutors.size === 1)
+ assert(trackedExecutors.contains(executorId1))
+ assert(!trackedExecutors.contains(executorId2))
+ }
+ test("expire dead hosts") {
+ val executorTimeout = executorTimeoutMs(heartbeatReceiver)
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ triggerHeartbeat(executorId2, executorShouldReregister = false)
+ // Advance the clock and only trigger a heartbeat for the first executor
+ heartbeatReceiverClock.advance(executorTimeout / 2)
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ heartbeatReceiverClock.advance(executorTimeout)
+ heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
+ // Only the second executor should be expired as a dead host
+ verify(scheduler).executorLost(Matchers.eq(executorId2), any())
+ val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ assert(trackedExecutors.size === 1)
+ assert(trackedExecutors.contains(executorId1))
+ assert(!trackedExecutors.contains(executorId2))
+ }
+
+ /** Manually send a heartbeat and return the response. */
+ private def triggerHeartbeat(
+ executorId: String,
+ executorShouldReregister: Boolean): Unit = {
val metrics = new TaskMetrics
- val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
- val response = receiverRef.askWithRetry[HeartbeatResponse](
- Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+ val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
+ val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
+ Heartbeat(executorId, Array(1L -> metrics), blockManagerId))
+ if (executorShouldReregister) {
+ assert(response.reregisterBlockManager)
+ } else {
+ assert(!response.reregisterBlockManager)
+ // Additionally verify that the scheduler callback is called with the correct parameters
+ verify(scheduler).executorHeartbeatReceived(
+ Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+ }
+ }
- verify(scheduler).executorHeartbeatReceived(
- Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
- assert(true === response.reregisterBlockManager)
+ // Helper methods to access private fields in HeartbeatReceiver
+ private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen)
+ private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs)
+ private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = {
+ receiver invokePrivate _executorLastSeen()
+ }
+ private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = {
+ receiver invokePrivate _executorTimeoutMs()
}
+
}