From daa1964b6098f79100def78451bda181b5c92198 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 27 Jul 2015 17:59:43 -0700 Subject: [SPARK-8882] [STREAMING] Add a new Receiver scheduling mechanism The design doc: https://docs.google.com/document/d/1ZsoRvHjpISPrDmSjsGzuSu8UjwgbtmoCTzmhgTurHJw/edit?usp=sharing Author: zsxwing Closes #7276 from zsxwing/receiver-scheduling and squashes the following commits: 137b257 [zsxwing] Add preferredNumExecutors to rescheduleReceiver 61a6c3f [zsxwing] Set state to ReceiverState.INACTIVE in deregisterReceiver 5e1fa48 [zsxwing] Fix the code style 7451498 [zsxwing] Move DummyReceiver back to ReceiverTrackerSuite 715ef9c [zsxwing] Rename: scheduledLocations -> scheduledExecutors; locations -> executors 05daf9c [zsxwing] Use receiverTrackingInfo.toReceiverInfo 1d6d7c8 [zsxwing] Merge branch 'master' into receiver-scheduling 8f93c8d [zsxwing] Use hostPort as the receiver location rather than host; fix comments and unit tests 59f8887 [zsxwing] Schedule all receivers at the same time when launching them 075e0a3 [zsxwing] Add receiver RDD name; use '!isTrackerStarted' instead 276a4ac [zsxwing] Remove "ReceiverLauncher" and move codes to "launchReceivers" fab9a01 [zsxwing] Move methods back to the outer class 4e639c4 [zsxwing] Fix unintentional changes f60d021 [zsxwing] Reorganize ReceiverTracker to use an event loop for lock free 105037e [zsxwing] Merge branch 'master' into receiver-scheduling 5fee132 [zsxwing] Update tha scheduling algorithm to avoid to keep restarting Receiver 9e242c8 [zsxwing] Remove the ScheduleReceiver message because we can refuse it when receiving RegisterReceiver a9acfbf [zsxwing] Merge branch 'squash-pr-6294' into receiver-scheduling 881edb9 [zsxwing] ReceiverScheduler -> ReceiverSchedulingPolicy e530bcc [zsxwing] [SPARK-5681][Streaming] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time #6294 3b87e4a [zsxwing] Revert SparkContext.scala a86850c [zsxwing] Remove submitAsyncJob and revert JobWaiter f549595 [zsxwing] Add comments for the scheduling approach 9ecc08e [zsxwing] Fix comments and code style 28d1bee [zsxwing] Make 'host' protected; rescheduleReceiver -> getAllowedLocations 2c86a9e [zsxwing] Use tryFailure to support calling jobFailed multiple times ca6fe35 [zsxwing] Add a test for Receiver.restart 27acd45 [zsxwing] Add unit tests for LoadBalanceReceiverSchedulerImplSuite cc76142 [zsxwing] Add JobWaiter.toFuture to avoid blocking threads d9a3e72 [zsxwing] Add a new Receiver scheduling mechanism --- .../streaming/receiver/ReceiverSupervisor.scala | 4 +- .../receiver/ReceiverSupervisorImpl.scala | 6 +- .../spark/streaming/scheduler/ReceiverInfo.scala | 1 - .../scheduler/ReceiverSchedulingPolicy.scala | 171 ++++++++ .../streaming/scheduler/ReceiverTracker.scala | 468 +++++++++++++-------- .../streaming/scheduler/ReceiverTrackingInfo.scala | 55 +++ .../scheduler/ReceiverSchedulingPolicySuite.scala | 130 ++++++ .../streaming/scheduler/ReceiverTrackerSuite.scala | 66 +-- .../ui/StreamingJobProgressListenerSuite.scala | 6 +- 9 files changed, 674 insertions(+), 233 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala (limited to 'streaming') diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index a7c220f426..e98017a637 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkEnv, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{Utils, ThreadUtils} /** * Abstract class that is responsible for supervising a Receiver in the worker. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 2f6841ee88..0d802f8354 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.RpcUtils import org.apache.spark.{Logging, SparkEnv, SparkException} /** @@ -46,6 +46,8 @@ private[streaming] class ReceiverSupervisorImpl( checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { + private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort + private val receivedBlockHandler: ReceivedBlockHandler = { if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { if (checkpointDirOption.isEmpty) { @@ -170,7 +172,7 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + streamId, receiver.getClass.getSimpleName, hostPort, endpoint) trackerEndpoint.askWithRetry[Boolean](msg) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index de85f24dd9..59df892397 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -28,7 +28,6 @@ import org.apache.spark.rpc.RpcEndpointRef case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val endpoint: RpcEndpointRef, active: Boolean, location: String, lastErrorMessage: String = "", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala new file mode 100644 index 0000000000..ef5b687b58 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.Map +import scala.collection.mutable + +import org.apache.spark.streaming.receiver.Receiver + +private[streaming] class ReceiverSchedulingPolicy { + + /** + * Try our best to schedule receivers with evenly distributed. However, if the + * `preferredLocation`s of receivers are not even, we may not be able to schedule them evenly + * because we have to respect them. + * + * Here is the approach to schedule executors: + *
+ *   1. First, schedule all the receivers with preferred locations (hosts), evenly among the
+ *      executors running on those hosts.
+ *   2. Then, schedule all other receivers evenly among all the executors such that the overall
+ *      distribution over all the receivers is even.
+ * + * This method is called when we start to launch receivers at the first time. + */ + def scheduleReceivers( + receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = { + if (receivers.isEmpty) { + return Map.empty + } + + if (executors.isEmpty) { + return receivers.map(_.streamId -> Seq.empty).toMap + } + + val hostToExecutors = executors.groupBy(_.split(":")(0)) + val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // Set the initial value to 0 + executors.foreach(e => numReceiversOnExecutor(e) = 0) + + // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation", + // we need to make sure the "preferredLocation" is in the candidate scheduled executor list. + for (i <- 0 until receivers.length) { + // Note: preferredLocation is host but executors are host:port + receivers(i).preferredLocation.foreach { host => + hostToExecutors.get(host) match { + case Some(executorsOnHost) => + // preferredLocation is a known host. Select an executor that has the least receivers in + // this host + val leastScheduledExecutor = + executorsOnHost.minBy(executor => numReceiversOnExecutor(executor)) + scheduledExecutors(i) += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = + numReceiversOnExecutor(leastScheduledExecutor) + 1 + case None => + // preferredLocation is an unknown host. + // Note: There are two cases: + // 1. This executor is not up. But it may be up later. + // 2. This executor is dead, or it's not a host in the cluster. + // Currently, simply add host to the scheduled executors. + scheduledExecutors(i) += host + } + } + } + + // For those receivers that don't have preferredLocation, make sure we assign at least one + // executor to them. + for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) { + // Select the executor that has the least receivers + val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2) + scheduledExecutorsForOneReceiver += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1 + } + + // Assign idle executors to receivers that have less executors + val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1) + for (executor <- idleExecutors) { + // Assign an idle executor to the receiver that has least candidate executors. + val leastScheduledExecutors = scheduledExecutors.minBy(_.size) + leastScheduledExecutors += executor + } + + receivers.map(_.streamId).zip(scheduledExecutors).toMap + } + + /** + * Return a list of candidate executors to run the receiver. If the list is empty, the caller can + * run this receiver in arbitrary executor. The caller can use `preferredNumExecutors` to require + * returning `preferredNumExecutors` executors if possible. + * + * This method tries to balance executors' load. Here is the approach to schedule executors + * for a receiver. + *
+ *   1. If preferredLocation is set, preferredLocation should be one of the candidate executors.
+ *   2. Every executor will be assigned a weight according to the receivers running on it or
+ *      scheduled to it:
+ *      - If a receiver is running on an executor, it contributes 1.0 to the executor's weight.
+ *      - If a receiver is scheduled to an executor but has not yet run, it contributes
+ *        `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.
+ *   3. Finally, if there are more than `preferredNumExecutors` idle executors (weight = 0),
+ *      return all idle executors. Otherwise, return only the `preferredNumExecutors` best options
+ *      according to the weights.
+ * + * This method is called when a receiver is registering with ReceiverTracker or is restarting. + */ + def rescheduleReceiver( + receiverId: Int, + preferredLocation: Option[String], + receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], + executors: Seq[String], + preferredNumExecutors: Int = 3): Seq[String] = { + if (executors.isEmpty) { + return Seq.empty + } + + // Always try to schedule to the preferred locations + val scheduledExecutors = mutable.Set[String]() + scheduledExecutors ++= preferredLocation + + val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo => + receiverTrackingInfo.state match { + case ReceiverState.INACTIVE => Nil + case ReceiverState.SCHEDULED => + val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get + // The probability that a scheduled receiver will run in an executor is + // 1.0 / scheduledLocations.size + scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size)) + case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) + } + }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + + val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq + if (idleExecutors.size >= preferredNumExecutors) { + // If there are more than `preferredNumExecutors` idle executors, return all of them + scheduledExecutors ++= idleExecutors + } else { + // If there are less than `preferredNumExecutors` idle executors, return 3 best options + scheduledExecutors ++= idleExecutors + val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1) + scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors) + } + scheduledExecutors.toSeq + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 9cc6ffcd12..6270137951 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,17 +17,27 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} +import java.util.concurrent.{TimeUnit, CountDownLatch} + +import scala.collection.mutable.HashMap +import scala.concurrent.ExecutionContext import scala.language.existentials -import scala.math.max +import scala.util.{Failure, Success} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, - StopReceiver, UpdateRateLimit} -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.{ThreadUtils, SerializableConfiguration} + + +/** Enumeration to identify current state of a Receiver */ +private[streaming] object ReceiverState extends Enumeration { + type ReceiverState = Value + val INACTIVE, SCHEDULED, ACTIVE = Value +} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -37,7 +47,7 @@ private[streaming] sealed trait ReceiverTrackerMessage private[streaming] case class RegisterReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, 
receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) @@ -46,7 +56,38 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage -private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage +/** + * Messages used by the driver and ReceiverTrackerEndpoint to communicate locally. + */ +private[streaming] sealed trait ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver. + */ +private[streaming] case class RestartReceiver(receiver: Receiver[_]) + extends ReceiverTrackerLocalMessage + +/** + * This message is sent to ReceiverTrackerEndpoint when we start to launch Spark jobs for receivers + * at the first time. + */ +private[streaming] case class StartAllReceivers(receiver: Seq[Receiver[_]]) + extends ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to send stop signals to all registered + * receivers. + */ +private[streaming] case object StopAllReceivers extends ReceiverTrackerLocalMessage + +/** + * A message used by ReceiverTracker to ask all receiver's ids still stored in + * ReceiverTrackerEndpoint. + */ +private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessage + +private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long) + extends ReceiverTrackerLocalMessage /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of @@ -60,8 +101,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private val receiverInputStreams = ssc.graph.getReceiverInputStreams() private val receiverInputStreamIds = receiverInputStreams.map { _.id } - private val receiverExecutor = new ReceiverLauncher() - private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] private val receivedBlockTracker = new ReceivedBlockTracker( ssc.sparkContext.conf, ssc.sparkContext.hadoopConfiguration, @@ -86,6 +125,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null + private val schedulingPolicy = new ReceiverSchedulingPolicy() + + // Track the active receiver job number. When a receiver job exits ultimately, countDown will + // be called. + private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size) + + /** + * Track all receivers' information. The key is the receiver id, the value is the receiver info. + * It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverTrackingInfos = new HashMap[Int, ReceiverTrackingInfo] + + /** + * Store all preferred locations for all receivers. We need this information to schedule + * receivers. It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverPreferredLocations = new HashMap[Int, Option[String]] + /** Start the endpoint and receiver execution thread. 
*/ def start(): Unit = synchronized { if (isTrackerStarted) { @@ -95,7 +152,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (!receiverInputStreams.isEmpty) { endpoint = ssc.env.rpcEnv.setupEndpoint( "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) - if (!skipReceiverLaunch) receiverExecutor.start() + if (!skipReceiverLaunch) launchReceivers() logInfo("ReceiverTracker started") trackerState = Started } @@ -112,20 +169,18 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Wait for the Spark job that runs the receivers to be over // That is, for the receivers to quit gracefully. - receiverExecutor.awaitTermination(10000) + receiverJobExitLatch.await(10, TimeUnit.SECONDS) if (graceful) { - val pollTime = 100 logInfo("Waiting for receiver job to terminate gracefully") - while (receiverInfo.nonEmpty || receiverExecutor.running) { - Thread.sleep(pollTime) - } + receiverJobExitLatch.await() logInfo("Waited for receiver job to terminate gracefully") } // Check if all the receivers have been deregistered or not - if (receiverInfo.nonEmpty) { - logWarning("Not all of the receivers have deregistered, " + receiverInfo) + val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds) + if (receivers.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receivers) } else { logInfo("All of the receivers have deregistered successfully") } @@ -154,9 +209,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Get the blocks allocated to the given batch and stream. */ def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { - synchronized { - receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) - } + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) } /** @@ -170,8 +223,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.endpoint) } - .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) } } @@ -179,7 +231,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private def registerReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress ): Boolean = { @@ -189,13 +241,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (isTrackerStopping || isTrackerStopped) { false + } else if (!scheduleReceiver(streamId).contains(hostPort)) { + // Refuse it since it's scheduled to a wrong executor + false } else { - // "stopReceivers" won't happen at the same time because both "registerReceiver" and are - // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If - // "stopReceivers" is called later, it should be able to see this receiver. 
- receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + val name = s"${typ}-${streamId}" + val receiverTrackingInfo = ReceiverTrackingInfo( + streamId, + ReceiverState.ACTIVE, + scheduledExecutors = None, + runningExecutor = Some(hostPort), + name = Some(name), + endpoint = Some(receiverEndpoint)) + receiverTrackingInfos.put(streamId, receiverTrackingInfo) + listenerBus.post(StreamingListenerReceiverStarted(receiverTrackingInfo.toReceiverInfo)) logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) true } @@ -203,21 +262,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val lastErrorTime = + if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() + val errorInfo = ReceiverErrorInfo( + lastErrorMessage = message, lastError = error, lastErrorTime = lastErrorTime) + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + oldInfo.copy(state = ReceiverState.INACTIVE, errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo -= streamId - listenerBus.post(StreamingListenerReceiverStopped(newReceiverInfo)) + receiverTrackingInfos -= streamId + listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -228,9 +286,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Update a receiver's maximum ingestion rate */ def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { - for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) { - eP.send(UpdateRateLimit(newRate)) - } + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) } /** Add new blocks for the given stream */ @@ -240,16 +296,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Report error sent by a receiver */ private def reportError(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(lastErrorMessage = message, lastError = error) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = oldInfo.errorInfo.map(_.lastErrorTime).getOrElse(-1L)) + oldInfo.copy(errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = ssc.scheduler.clock.getTimeMillis()) + val errorInfo = 
ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = ssc.scheduler.clock.getTimeMillis()) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo(streamId) = newReceiverInfo - listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + + receiverTrackingInfos(streamId) = newReceiverTrackingInfo + listenerBus.post(StreamingListenerReceiverError(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -258,171 +319,242 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logWarning(s"Error reported by receiver for stream $streamId: $messageWithError") } + private def scheduleReceiver(receiverId: Int): Seq[String] = { + val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None) + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiverId, preferredLocation, receiverTrackingInfos, getExecutors) + updateReceiverScheduledExecutors(receiverId, scheduledExecutors) + scheduledExecutors + } + + private def updateReceiverScheduledExecutors( + receiverId: Int, scheduledExecutors: Seq[String]): Unit = { + val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match { + case Some(oldInfo) => + oldInfo.copy(state = ReceiverState.SCHEDULED, + scheduledExecutors = Some(scheduledExecutors)) + case None => + ReceiverTrackingInfo( + receiverId, + ReceiverState.SCHEDULED, + Some(scheduledExecutors), + runningExecutor = None) + } + receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo) + } + /** Check if any blocks are left to be processed */ def hasUnallocatedBlocks: Boolean = { receivedBlockTracker.hasUnallocatedReceivedBlocks } + /** + * Get the list of executors excluding driver + */ + private def getExecutors: Seq[String] = { + if (ssc.sc.isLocal) { + Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort) + } else { + ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => + blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location + }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq + } + } + + /** + * Run the dummy Spark job to ensure that all slaves have registered. This avoids all the + * receivers to be scheduled on the same node. + * + * TODO Should poll the executor number and wait for executors according to + * "spark.scheduler.minRegisteredResourcesRatio" and + * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job. + */ + private def runDummySparkJob(): Unit = { + if (!ssc.sparkContext.isLocal) { + ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() + } + assert(getExecutors.nonEmpty) + } + + /** + * Get the receivers from the ReceiverInputDStreams, distributes them to the + * worker nodes as a parallel collection, and runs them. 
+ */ + private def launchReceivers(): Unit = { + val receivers = receiverInputStreams.map(nis => { + val rcvr = nis.getReceiver() + rcvr.setReceiverId(nis.id) + rcvr + }) + + runDummySparkJob() + + logInfo("Starting " + receivers.length + " receivers") + endpoint.send(StartAllReceivers(receivers)) + } + + /** Check if tracker has been marked for starting */ + private def isTrackerStarted: Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping: Boolean = trackerState == Stopping + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped: Boolean = trackerState == Stopped + /** RpcEndpoint to receive messages from the receivers. */ private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged + private val submitJobThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + override def receive: PartialFunction[Any, Unit] = { + // Local messages + case StartAllReceivers(receivers) => + val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors) + for (receiver <- receivers) { + val executors = scheduledExecutors(receiver.streamId) + updateReceiverScheduledExecutors(receiver.streamId, executors) + receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation + startReceiver(receiver, executors) + } + case RestartReceiver(receiver) => + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors) + startReceiver(receiver, scheduledExecutors) + case c: CleanupOldBlocks => + receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) + case UpdateReceiverRateLimit(streamUID, newRate) => + for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) { + eP.send(UpdateRateLimit(newRate)) + } + // Remote messages case ReportError(streamId, message, error) => reportError(streamId, message, error) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterReceiver(streamId, typ, host, receiverEndpoint) => + // Remote messages + case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + // Local messages + case AllReceiverIds => + context.reply(receiverTrackingInfos.keys.toSeq) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() context.reply(true) } - /** Send stop signal to the receivers. */ - private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.endpoint)} - .foreach { _.send(StopReceiver) } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") - } - } - - /** This thread class runs all the receivers on the cluster. 
*/ - class ReceiverLauncher { - @transient val env = ssc.env - @volatile @transient var running = false - @transient val thread = new Thread() { - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") - } - } - } - - def start() { - thread.start() - } - /** - * Get the list of executors excluding driver - */ - private def getExecutors(ssc: StreamingContext): List[String] = { - val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList - val driver = ssc.sparkContext.getConf.get("spark.driver.host") - executors.diff(List(driver)) - } - - /** Set host location(s) for each receiver so as to distribute them over - * executors in a round-robin fashion taking into account preferredLocation if set + * Start a receiver along with its scheduled executors */ - private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]], - executors: List[String]): Array[ArrayBuffer[String]] = { - val locations = new Array[ArrayBuffer[String]](receivers.length) - var i = 0 - for (i <- 0 until receivers.length) { - locations(i) = new ArrayBuffer[String]() - if (receivers(i).preferredLocation.isDefined) { - locations(i) += receivers(i).preferredLocation.get - } + private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + val receiverId = receiver.streamId + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + return } - var count = 0 - for (i <- 0 until max(receivers.length, executors.length)) { - if (!receivers(i % receivers.length).preferredLocation.isDefined) { - locations(i % receivers.length) += executors(count) - count += 1 - if (count == executors.length) { - count = 0 - } - } - } - locations - } - - /** - * Get the receivers from the ReceiverInputDStreams, distributes them to the - * worker nodes as a parallel collection, and runs them. - */ - private def startReceivers() { - val receivers = receiverInputStreams.map(nis => { - val rcvr = nis.getReceiver() - rcvr.setReceiverId(nis.id) - rcvr - }) val checkpointDirOption = Option(ssc.checkpointDir) val serializableHadoopConf = new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiver = (iterator: Iterator[Receiver[_]]) => { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") - } - val receiver = iterator.next() - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } - - // Run the dummy Spark job to ensure that all slaves have registered. - // This avoids all the receivers to be scheduled on the same node. 
- if (!ssc.sparkContext.isLocal) { - ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() - } + val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf) - // Get the list of executors and schedule receivers - val executors = getExecutors(ssc) - val tempRDD = - if (!executors.isEmpty) { - val locations = scheduleReceivers(receivers, executors) - val roundRobinReceivers = (0 until receivers.length).map(i => - (receivers(i), locations(i))) - ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers) + // Create the RDD using the scheduledExecutors to run the receiver in a Spark job + val receiverRDD: RDD[Receiver[_]] = + if (scheduledExecutors.isEmpty) { + ssc.sc.makeRDD(Seq(receiver), 1) } else { - ssc.sc.makeRDD(receivers, receivers.size) + ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) } + receiverRDD.setName(s"Receiver $receiverId") + val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit]( + receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ()) + // We will keep restarting the receiver job until ReceiverTracker is stopped + future.onComplete { + case Success(_) => + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + } else { + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + case Failure(e) => + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + } else { + logError("Receiver has been stopped. Try to restart it.", e) + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + }(submitJobThreadPool) + logInfo(s"Receiver ${receiver.streamId} started") + } - // Distribute the receivers and start them - logInfo("Starting " + receivers.length + " receivers") - running = true - try { - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - logInfo("All of the receivers have been terminated") - } finally { - running = false - } + override def onStop(): Unit = { + submitJobThreadPool.shutdownNow() } /** - * Wait until the Spark job that runs the receivers is terminated, or return when - * `milliseconds` elapses + * Call when a receiver is terminated. It means we won't restart its Spark job. */ - def awaitTermination(milliseconds: Long): Unit = { - thread.join(milliseconds) + private def onReceiverJobFinish(receiverId: Int): Unit = { + receiverJobExitLatch.countDown() + receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo => + if (receiverTrackingInfo.state == ReceiverState.ACTIVE) { + logWarning(s"Receiver $receiverId exited but didn't deregister") + } + } } - } - /** Check if tracker has been marked for starting */ - private def isTrackerStarted(): Boolean = trackerState == Started + /** Send stop signal to the receivers. */ + private def stopReceivers() { + receiverTrackingInfos.values.flatMap(_.endpoint).foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverTrackingInfos.size + " receivers") + } + } - /** Check if tracker has been marked for stopping */ - private def isTrackerStopping(): Boolean = trackerState == Stopping +} - /** Check if tracker has been marked for stopped */ - private def isTrackerStopped(): Boolean = trackerState == Stopped +/** + * Function to start the receiver on the worker node. Use a class instead of closure to avoid + * the serialization issue. 
+ */ +private class StartReceiverFunc( + checkpointDirOption: Option[String], + serializableHadoopConf: SerializableConfiguration) + extends (Iterator[Receiver[_]] => Unit) with Serializable { + + override def apply(iterator: Iterator[Receiver[_]]): Unit = { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala new file mode 100644 index 0000000000..043ff4d0ff --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.streaming.scheduler.ReceiverState._ + +private[streaming] case class ReceiverErrorInfo( + lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L) + +/** + * Class having information about a receiver. + * + * @param receiverId the unique receiver id + * @param state the current Receiver state + * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy + * @param runningExecutor the running executor if the receiver is active + * @param name the receiver name + * @param endpoint the receiver endpoint. 
It can be used to send messages to the receiver + * @param errorInfo the receiver error information if it fails + */ +private[streaming] case class ReceiverTrackingInfo( + receiverId: Int, + state: ReceiverState, + scheduledExecutors: Option[Seq[String]], + runningExecutor: Option[String], + name: Option[String] = None, + endpoint: Option[RpcEndpointRef] = None, + errorInfo: Option[ReceiverErrorInfo] = None) { + + def toReceiverInfo: ReceiverInfo = ReceiverInfo( + receiverId, + name.getOrElse(""), + state == ReceiverState.ACTIVE, + location = runningExecutor.getOrElse(""), + lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), + lastError = errorInfo.map(_.lastError).getOrElse(""), + lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) + ) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala new file mode 100644 index 0000000000..93f920fdc7 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite + +class ReceiverSchedulingPolicySuite extends SparkFunSuite { + + val receiverSchedulingPolicy = new ReceiverSchedulingPolicy + + test("rescheduleReceiver: empty executors") { + val scheduledExecutors = + receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty) + assert(scheduledExecutors === Seq.empty) + } + + test("rescheduleReceiver: receiver preferredLocation") { + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2")) + assert(scheduledExecutors.toSet === Set("host1", "host2")) + } + + test("rescheduleReceiver: return all idle executors if more than 3 idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // host3 is idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1"))) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 1, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) + } + + test("rescheduleReceiver: return 3 best options if less than 3 idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0 + // host4 and host5 are idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), + 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), + 2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 3, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) + } + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more receivers than executors") { + val receivers = (0 until 6).map(new DummyReceiver(_)) + val executors = (10000 until 10003).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 2 receivers running on each executor and each receiver has one executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1 + } + assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap) + } + + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more executors than receivers") { + val receivers = (0 until 3).map(new DummyReceiver(_)) + val executors = (10000 until 10006).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has two executors + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 2) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 
1).toMap) + } + + test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { + val receivers = (0 until 3).map(new DummyReceiver(_)) ++ + (3 until 6).map(new DummyReceiver(_, Some("localhost"))) + val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ + (10003 until 10006).map(port => s"localhost2:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has 1 executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + // Make sure we schedule the receivers to their preferredLocations + val executorsForReceiversWithPreferredLocation = + scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) + // We can simply check the executor set because we only know each receiver only has 1 executor + assert(executorsForReceiversWithPreferredLocation.toSet === + (10000 until 10003).map(port => s"localhost:${port}").toSet) + } + + test("scheduleReceivers: return empty if no receiver") { + assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty) + } + + test("scheduleReceivers: return empty scheduled executors if no executors") { + val receivers = (0 until 3).map(new DummyReceiver(_)) + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.isEmpty) + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index aadb723175..e2159bd4f2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -18,66 +18,18 @@ package org.apache.spark.streaming.scheduler import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.streaming._ + import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ import org.apache.spark.streaming.receiver._ -import org.apache.spark.util.Utils -import org.apache.spark.streaming.dstream.InputDStream -import scala.reflect.ClassTag import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.storage.StorageLevel /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") val ssc = new StreamingContext(sparkConf, Milliseconds(100)) - val tracker = new ReceiverTracker(ssc) - val launcher = new tracker.ReceiverLauncher() - val executors: List[String] = List("0", "1", "2", "3") - - test("receiver scheduling - all or none have preferred location") { - - def parse(s: String): Array[Array[String]] = { - val outerSplit = s.split("\\|") - val loc = new Array[Array[String]](outerSplit.length) - var i = 0 - for (i <- 0 until outerSplit.length) { - loc(i) = outerSplit(i).split("\\,") - } - loc - } - - def 
testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { - val receivers = - if (preferredLocation) { - Array.tabulate(numReceivers)(i => new DummyReceiver(host = - Some(((i + 1) % executors.length).toString))) - } else { - Array.tabulate(numReceivers)(_ => new DummyReceiver) - } - val locations = launcher.scheduleReceivers(receivers, executors) - val expectedLocations = parse(allocation) - assert(locations.deep === expectedLocations.deep) - } - - testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") - testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") - testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") - } - - test("receiver scheduling - some have preferred location") { - val numReceivers = 4; - val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), - new DummyReceiver, new DummyReceiver, new DummyReceiver) - val locations = launcher.scheduleReceivers(receivers, executors) - assert(locations(0)(0) === "1") - assert(locations(1)(0) === "0") - assert(locations(2)(0) === "1") - assert(locations(0).length === 1) - assert(locations(3).length === 1) - } test("Receiver tracker - propagates rate limit") { object ReceiverStartedWaiter extends StreamingListener { @@ -134,19 +86,19 @@ private class RateLimitInputDStream(@transient ssc_ : StreamingContext) * @note It's necessary to be a top-level object, or else serialization would create another * one on the executor side and we won't be able to read its rate limit. */ -private object SingletonDummyReceiver extends DummyReceiver +private object SingletonDummyReceiver extends DummyReceiver(0) /** * Dummy receiver implementation */ -private class DummyReceiver(host: Option[String] = None) +private class DummyReceiver(receiverId: Int, host: Option[String] = None) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - def onStart() { - } + setReceiverId(receiverId) - def onStop() { - } + override def onStart(): Unit = {} + + override def onStop(): Unit = {} override def preferredLocation: Option[String] = host } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 40dc1fb601..0891309f95 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -119,20 +119,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", null, 
true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) -- cgit v1.2.3
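
Editor's note: the heart of the new scheduling mechanism is the executor-weight rule documented on `rescheduleReceiver` above: a running receiver contributes 1.0 to its executor's weight, and a scheduled-but-not-yet-running receiver contributes 1.0 / #candidate_executors to each of its candidates. The standalone Scala sketch below (hypothetical names, no Spark dependencies, not the patched class itself) reproduces that arithmetic so the weights asserted in ReceiverSchedulingPolicySuite (host1 = 1.5, host2 = 0.5, host3 = 1.0) can be checked by hand.

object WeightSketch {
  sealed trait ReceiverPlacement
  case class Running(executor: String) extends ReceiverPlacement
  case class Scheduled(candidates: Seq[String]) extends ReceiverPlacement

  // Sum each executor's weight: 1.0 per running receiver, and
  // 1.0 / #candidates per receiver that is scheduled there but not yet running.
  def executorWeights(receivers: Seq[ReceiverPlacement]): Map[String, Double] =
    receivers.flatMap {
      case Running(executor)     => Seq(executor -> 1.0)
      case Scheduled(candidates) => candidates.map(_ -> 1.0 / candidates.size)
    }.groupBy(_._1).map { case (executor, weights) => executor -> weights.map(_._2).sum }

  def main(args: Array[String]): Unit = {
    // Same scenario as the "return 3 best options" test in ReceiverSchedulingPolicySuite.
    val weights = executorWeights(Seq(
      Running("host1"),
      Scheduled(Seq("host2", "host3")),
      Scheduled(Seq("host1", "host3"))))
    println(weights) // e.g. Map(host1 -> 1.5, host2 -> 0.5, host3 -> 1.0) (ordering may vary)
  }
}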
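
A second pattern worth calling out: each receiver now runs as its own one-partition Spark job, and ReceiverTrackerEndpoint resubmits that job whenever it finishes while the tracker is still started (see the `future.onComplete` block in `startReceiver`). Below is a minimal sketch of that resubmit-on-completion loop using only scala.concurrent, with hypothetical stand-ins for the Spark job; it is illustrative only, not the tracker code.

import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object RestartLoopSketch {
  implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutor(Executors.newCachedThreadPool())

  @volatile var trackerStarted = true  // stand-in for isTrackerStarted

  // Stand-in for ssc.sparkContext.submitJob(receiverRDD, startReceiverFunc, ...).
  def runReceiverJob(receiverId: Int): Future[Unit] = Future {
    Thread.sleep(100) // pretend the receiver ran for a while and then exited
  }

  def startReceiver(receiverId: Int): Unit = {
    runReceiverJob(receiverId).onComplete {
      case Success(_) if trackerStarted =>
        println(s"Restarting receiver $receiverId")           // job ended; resubmit it
        startReceiver(receiverId)
      case Failure(e) if trackerStarted =>
        println(s"Receiver $receiverId job failed ($e); restarting")
        startReceiver(receiverId)
      case _ =>
        println(s"Receiver $receiverId terminated")           // tracker stopped; let it end
    }
  }
}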