author:    zsxwing <zsxwing@gmail.com>  2015-07-27 17:59:43 -0700
committer: Tathagata Das <tathagata.das1565@gmail.com>  2015-07-27 17:59:43 -0700
commit:    daa1964b6098f79100def78451bda181b5c92198 (patch)
tree:      36dbff0782eaa1e0c217808fb1bf992b51e3c0e7 /streaming
parent:    ce89ff477aea6def68265ed218f6105680755c9a (diff)
[SPARK-8882] [STREAMING] Add a new Receiver scheduling mechanism
The design doc: https://docs.google.com/document/d/1ZsoRvHjpISPrDmSjsGzuSu8UjwgbtmoCTzmhgTurHJw/edit?usp=sharing

Author: zsxwing <zsxwing@gmail.com>

Closes #7276 from zsxwing/receiver-scheduling and squashes the following commits:

137b257 [zsxwing] Add preferredNumExecutors to rescheduleReceiver
61a6c3f [zsxwing] Set state to ReceiverState.INACTIVE in deregisterReceiver
5e1fa48 [zsxwing] Fix the code style
7451498 [zsxwing] Move DummyReceiver back to ReceiverTrackerSuite
715ef9c [zsxwing] Rename: scheduledLocations -> scheduledExecutors; locations -> executors
05daf9c [zsxwing] Use receiverTrackingInfo.toReceiverInfo
1d6d7c8 [zsxwing] Merge branch 'master' into receiver-scheduling
8f93c8d [zsxwing] Use hostPort as the receiver location rather than host; fix comments and unit tests
59f8887 [zsxwing] Schedule all receivers at the same time when launching them
075e0a3 [zsxwing] Add receiver RDD name; use '!isTrackerStarted' instead
276a4ac [zsxwing] Remove "ReceiverLauncher" and move code to "launchReceivers"
fab9a01 [zsxwing] Move methods back to the outer class
4e639c4 [zsxwing] Fix unintentional changes
f60d021 [zsxwing] Reorganize ReceiverTracker to use an event loop for lock-free operation
105037e [zsxwing] Merge branch 'master' into receiver-scheduling
5fee132 [zsxwing] Update the scheduling algorithm to avoid repeatedly restarting the Receiver
9e242c8 [zsxwing] Remove the ScheduleReceiver message because we can refuse it when receiving RegisterReceiver
a9acfbf [zsxwing] Merge branch 'squash-pr-6294' into receiver-scheduling
881edb9 [zsxwing] ReceiverScheduler -> ReceiverSchedulingPolicy
e530bcc [zsxwing] [SPARK-5681][Streaming] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time #6294
3b87e4a [zsxwing] Revert SparkContext.scala
a86850c [zsxwing] Remove submitAsyncJob and revert JobWaiter
f549595 [zsxwing] Add comments for the scheduling approach
9ecc08e [zsxwing] Fix comments and code style
28d1bee [zsxwing] Make 'host' protected; rescheduleReceiver -> getAllowedLocations
2c86a9e [zsxwing] Use tryFailure to support calling jobFailed multiple times
ca6fe35 [zsxwing] Add a test for Receiver.restart
27acd45 [zsxwing] Add unit tests for LoadBalanceReceiverSchedulerImplSuite
cc76142 [zsxwing] Add JobWaiter.toFuture to avoid blocking threads
d9a3e72 [zsxwing] Add a new Receiver scheduling mechanism
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala             |   4
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala         |   6
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala                  |   1
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala      | 171
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala               | 468
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala          |  55
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala | 130
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala          |  66
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala    |   6
9 files changed, 674 insertions, 233 deletions
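At the core of this patch is a per-executor load weight, computed in rescheduleReceiver below: a running receiver contributes 1.0 to its executor, and a receiver that is only scheduled contributes 1.0 / #candidate_executors to each of its candidate executors. A minimal standalone Scala sketch of that weight computation (the object and parameter names here are illustrative, not part of the patch):

    // Minimal sketch of the executor-weight computation used by rescheduleReceiver.
    // "running" and "scheduled" are hypothetical inputs, not part of the patch.
    object WeightSketch {
      def executorWeights(
          running: Seq[String],        // executors hosting an ACTIVE receiver
          scheduled: Seq[Seq[String]]  // candidate executors per SCHEDULED receiver
        ): Map[String, Double] = {
        val runningWeights = running.map(_ -> 1.0)
        // A scheduled receiver may end up on any of its candidates, so each
        // candidate gets 1.0 / #candidates.
        val scheduledWeights = scheduled.flatMap { candidates =>
          candidates.map(_ -> 1.0 / candidates.size)
        }
        (runningWeights ++ scheduledWeights).groupBy(_._1).mapValues(_.map(_._2).sum).toMap
      }

      def main(args: Array[String]): Unit = {
        // Mirrors the weights in ReceiverSchedulingPolicySuite below:
        // host1 = 1.0 + 0.5 = 1.5, host2 = 0.5, host3 = 0.5 + 0.5 = 1.0
        println(executorWeights(
          running = Seq("host1"),
          scheduled = Seq(Seq("host2", "host3"), Seq("host1", "host3"))))
      }
    }

Running it prints Map(host1 -> 1.5, host2 -> 0.5, host3 -> 1.0) (in some order), the same weights exercised by ReceiverSchedulingPolicySuite further down.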
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index a7c220f426..e98017a637 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent._
import scala.util.control.NonFatal
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SparkEnv, Logging, SparkConf}
import org.apache.spark.storage.StreamBlockId
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{Utils, ThreadUtils}
/**
* Abstract class that is responsible for supervising a Receiver in the worker.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 2f6841ee88..0d802f8354 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -30,7 +30,7 @@ import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.scheduler._
import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.util.{RpcUtils, Utils}
+import org.apache.spark.util.RpcUtils
import org.apache.spark.{Logging, SparkEnv, SparkException}
/**
@@ -46,6 +46,8 @@ private[streaming] class ReceiverSupervisorImpl(
checkpointDirOption: Option[String]
) extends ReceiverSupervisor(receiver, env.conf) with Logging {
+ private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort
+
private val receivedBlockHandler: ReceivedBlockHandler = {
if (WriteAheadLogUtils.enableReceiverLog(env.conf)) {
if (checkpointDirOption.isEmpty) {
@@ -170,7 +172,7 @@ private[streaming] class ReceiverSupervisorImpl(
override protected def onReceiverStart(): Boolean = {
val msg = RegisterReceiver(
- streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
+ streamId, receiver.getClass.getSimpleName, hostPort, endpoint)
trackerEndpoint.askWithRetry[Boolean](msg)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala
index de85f24dd9..59df892397 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala
@@ -28,7 +28,6 @@ import org.apache.spark.rpc.RpcEndpointRef
case class ReceiverInfo(
streamId: Int,
name: String,
- private[streaming] val endpoint: RpcEndpointRef,
active: Boolean,
location: String,
lastErrorMessage: String = "",
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
new file mode 100644
index 0000000000..ef5b687b58
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.Map
+import scala.collection.mutable
+
+import org.apache.spark.streaming.receiver.Receiver
+
+private[streaming] class ReceiverSchedulingPolicy {
+
+ /**
+ * Try our best to schedule receivers so that they are evenly distributed. However, if the
+ * `preferredLocation`s of receivers are not evenly distributed, we may not be able to
+ * schedule them evenly because we have to respect them.
+ *
+ * Here is the approach used to schedule executors:
+ * <ol>
+ *   <li>First, schedule all the receivers that have preferred locations (hosts) evenly among
+ *       the executors running on those hosts.</li>
+ *   <li>Then, schedule all other receivers evenly among all the executors such that the
+ *       overall distribution over all the receivers is even.</li>
+ * </ol>
+ *
+ * This method is called when the receivers are launched for the first time.
+ */
+ def scheduleReceivers(
+ receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = {
+ if (receivers.isEmpty) {
+ return Map.empty
+ }
+
+ if (executors.isEmpty) {
+ return receivers.map(_.streamId -> Seq.empty).toMap
+ }
+
+ val hostToExecutors = executors.groupBy(_.split(":")(0))
+ val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String])
+ val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+ // Set the initial value to 0
+ executors.foreach(e => numReceiversOnExecutor(e) = 0)
+
+ // First, we need to respect "preferredLocation". So if a receiver has a "preferredLocation",
+ // we need to make sure the "preferredLocation" is in the candidate scheduled executor list.
+ for (i <- 0 until receivers.length) {
+ // Note: preferredLocation is host but executors are host:port
+ receivers(i).preferredLocation.foreach { host =>
+ hostToExecutors.get(host) match {
+ case Some(executorsOnHost) =>
+ // preferredLocation is a known host. Select the executor on this host that has
+ // the fewest receivers.
+ val leastScheduledExecutor =
+ executorsOnHost.minBy(executor => numReceiversOnExecutor(executor))
+ scheduledExecutors(i) += leastScheduledExecutor
+ numReceiversOnExecutor(leastScheduledExecutor) =
+ numReceiversOnExecutor(leastScheduledExecutor) + 1
+ case None =>
+ // preferredLocation is an unknown host.
+ // Note: there are two cases:
+ // 1. The executor on this host is not up yet, but may come up later.
+ // 2. This host is dead, or it's not a host in the cluster.
+ // For now, simply add the host to the scheduled executors.
+ scheduledExecutors(i) += host
+ }
+ }
+ }
+
+ // For those receivers that don't have preferredLocation, make sure we assign at least one
+ // executor to them.
+ for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) {
+ // Select the executor that has the fewest receivers
+ val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2)
+ scheduledExecutorsForOneReceiver += leastScheduledExecutor
+ numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1
+ }
+
+ // Assign idle executors to the receivers that have the fewest candidate executors
+ val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1)
+ for (executor <- idleExecutors) {
+ // Assign an idle executor to the receiver that has the fewest candidate executors.
+ val leastScheduledExecutors = scheduledExecutors.minBy(_.size)
+ leastScheduledExecutors += executor
+ }
+
+ receivers.map(_.streamId).zip(scheduledExecutors).toMap
+ }
+
+ /**
+ * Return a list of candidate executors to run the receiver. If the list is empty, the caller
+ * can run this receiver on an arbitrary executor. The caller can use `preferredNumExecutors`
+ * to ask that `preferredNumExecutors` executors be returned if possible.
+ *
+ * This method tries to balance the executors' load. Here is the approach used to schedule
+ * executors for a receiver:
+ * <ol>
+ * <li>
+ * If preferredLocation is set, preferredLocation should be one of the candidate executors.
+ * </li>
+ * <li>
+ * Every executor is assigned a weight according to the receivers running on it or
+ * scheduled to it.
+ * <ul>
+ * <li>
+ * If a receiver is running on an executor, it contributes 1.0 to the executor's weight.
+ * </li>
+ * <li>
+ * If a receiver is scheduled to an executor but has not yet run, it contributes
+ * `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.</li>
+ * </ul>
+ * Finally, if there are more than `preferredNumExecutors` idle executors (weight = 0),
+ * return all idle executors. Otherwise, return only the `preferredNumExecutors` best
+ * options according to the weights.
+ * </li>
+ * </ol>
+ *
+ * This method is called when a receiver is registering with ReceiverTracker or is restarting.
+ */
+ def rescheduleReceiver(
+ receiverId: Int,
+ preferredLocation: Option[String],
+ receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo],
+ executors: Seq[String],
+ preferredNumExecutors: Int = 3): Seq[String] = {
+ if (executors.isEmpty) {
+ return Seq.empty
+ }
+
+ // Always try to schedule to the preferred locations
+ val scheduledExecutors = mutable.Set[String]()
+ scheduledExecutors ++= preferredLocation
+
+ val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo =>
+ receiverTrackingInfo.state match {
+ case ReceiverState.INACTIVE => Nil
+ case ReceiverState.SCHEDULED =>
+ val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get
+ // The probability that a scheduled receiver will run on a particular executor is
+ // 1.0 / scheduledExecutors.size
+ scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size))
+ case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0)
+ }
+ }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor
+
+ val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq
+ if (idleExecutors.size >= preferredNumExecutors) {
+ // If there are at least `preferredNumExecutors` idle executors, return all of them
+ scheduledExecutors ++= idleExecutors
+ } else {
+ // If there are fewer than `preferredNumExecutors` idle executors, return the
+ // `preferredNumExecutors` best options according to the weights
+ scheduledExecutors ++= idleExecutors
+ val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1)
+ scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors)
+ }
+ scheduledExecutors.toSeq
+ }
+}
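A hedged usage sketch of the policy above, mirroring the "more than 3 idle executors" case in ReceiverSchedulingPolicySuite below. It assumes compilation inside the org.apache.spark.streaming.scheduler package against this patch (these classes are private[streaming]); the demo object itself is not part of the patch:

    package org.apache.spark.streaming.scheduler

    object SchedulingPolicyDemo {
      def main(args: Array[String]): Unit = {
        val policy = new ReceiverSchedulingPolicy
        // Receiver 0 is ACTIVE on host1, so host2..host5 are idle (weight 0).
        val infos = Map(0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")))
        val executors = Seq("host1", "host2", "host3", "host4", "host5")
        // More than preferredNumExecutors (3) executors are idle, so all idle
        // executors are returned.
        println(policy.rescheduleReceiver(1, None, infos, executors).sorted)
        // => List(host2, host3, host4, host5)
      }
    }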
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 9cc6ffcd12..6270137951 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -17,17 +17,27 @@
package org.apache.spark.streaming.scheduler
-import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap}
+import java.util.concurrent.{TimeUnit, CountDownLatch}
+
+import scala.collection.mutable.HashMap
+import scala.concurrent.ExecutionContext
import scala.language.existentials
-import scala.math.max
+import scala.util.{Failure, Success}
import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
import org.apache.spark.rpc._
import org.apache.spark.streaming.{StreamingContext, Time}
-import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl,
- StopReceiver, UpdateRateLimit}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.streaming.receiver._
+import org.apache.spark.util.{ThreadUtils, SerializableConfiguration}
+
+
+/** Enumeration to identify the current state of a Receiver */
+private[streaming] object ReceiverState extends Enumeration {
+ type ReceiverState = Value
+ val INACTIVE, SCHEDULED, ACTIVE = Value
+}
/**
* Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -37,7 +47,7 @@ private[streaming] sealed trait ReceiverTrackerMessage
private[streaming] case class RegisterReceiver(
streamId: Int,
typ: String,
- host: String,
+ hostPort: String,
receiverEndpoint: RpcEndpointRef
) extends ReceiverTrackerMessage
private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo)
@@ -46,7 +56,38 @@ private[streaming] case class ReportError(streamId: Int, message: String, error:
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String)
extends ReceiverTrackerMessage
-private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage
+/**
+ * Messages used by the driver and ReceiverTrackerEndpoint to communicate locally.
+ */
+private[streaming] sealed trait ReceiverTrackerLocalMessage
+
+/**
+ * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver.
+ */
+private[streaming] case class RestartReceiver(receiver: Receiver[_])
+ extends ReceiverTrackerLocalMessage
+
+/**
+ * This message is sent to ReceiverTrackerEndpoint when we launch Spark jobs for the receivers
+ * for the first time.
+ */
+private[streaming] case class StartAllReceivers(receivers: Seq[Receiver[_]])
+ extends ReceiverTrackerLocalMessage
+
+/**
+ * This message will trigger ReceiverTrackerEndpoint to send stop signals to all registered
+ * receivers.
+ */
+private[streaming] case object StopAllReceivers extends ReceiverTrackerLocalMessage
+
+/**
+ * A message used by ReceiverTracker to ask for the ids of all receivers still stored in
+ * ReceiverTrackerEndpoint.
+ */
+private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessage
+
+private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long)
+ extends ReceiverTrackerLocalMessage
/**
* This class manages the execution of the receivers of ReceiverInputDStreams. Instance of
@@ -60,8 +101,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
private val receiverInputStreams = ssc.graph.getReceiverInputStreams()
private val receiverInputStreamIds = receiverInputStreams.map { _.id }
- private val receiverExecutor = new ReceiverLauncher()
- private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
private val receivedBlockTracker = new ReceivedBlockTracker(
ssc.sparkContext.conf,
ssc.sparkContext.hadoopConfiguration,
@@ -86,6 +125,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// This not being null means the tracker has been started and not stopped
private var endpoint: RpcEndpointRef = null
+ private val schedulingPolicy = new ReceiverSchedulingPolicy()
+
+ // Track the number of active receiver jobs. When a receiver job finally exits, countDown
+ // will be called.
+ private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size)
+
+ /**
+ * Track all receivers' information. The key is the receiver id, the value is the receiver info.
+ * It's only accessed in ReceiverTrackerEndpoint.
+ */
+ private val receiverTrackingInfos = new HashMap[Int, ReceiverTrackingInfo]
+
+ /**
+ * Store all preferred locations for all receivers. We need this information to schedule
+ * receivers. It's only accessed in ReceiverTrackerEndpoint.
+ */
+ private val receiverPreferredLocations = new HashMap[Int, Option[String]]
+
/** Start the endpoint and receiver execution thread. */
def start(): Unit = synchronized {
if (isTrackerStarted) {
@@ -95,7 +152,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
if (!receiverInputStreams.isEmpty) {
endpoint = ssc.env.rpcEnv.setupEndpoint(
"ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv))
- if (!skipReceiverLaunch) receiverExecutor.start()
+ if (!skipReceiverLaunch) launchReceivers()
logInfo("ReceiverTracker started")
trackerState = Started
}
@@ -112,20 +169,18 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// Wait for the Spark job that runs the receivers to be over
// That is, for the receivers to quit gracefully.
- receiverExecutor.awaitTermination(10000)
+ receiverJobExitLatch.await(10, TimeUnit.SECONDS)
if (graceful) {
- val pollTime = 100
logInfo("Waiting for receiver job to terminate gracefully")
- while (receiverInfo.nonEmpty || receiverExecutor.running) {
- Thread.sleep(pollTime)
- }
+ receiverJobExitLatch.await()
logInfo("Waited for receiver job to terminate gracefully")
}
// Check if all the receivers have been deregistered or not
- if (receiverInfo.nonEmpty) {
- logWarning("Not all of the receivers have deregistered, " + receiverInfo)
+ val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds)
+ if (receivers.nonEmpty) {
+ logWarning("Not all of the receivers have deregistered, " + receivers)
} else {
logInfo("All of the receivers have deregistered successfully")
}
@@ -154,9 +209,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
/** Get the blocks allocated to the given batch and stream. */
def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
- synchronized {
- receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId)
- }
+ receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId)
}
/**
@@ -170,8 +223,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// Signal the receivers to delete old block data
if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) {
logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
- receiverInfo.values.flatMap { info => Option(info.endpoint) }
- .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) }
+ endpoint.send(CleanupOldBlocks(cleanupThreshTime))
}
}
@@ -179,7 +231,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
private def registerReceiver(
streamId: Int,
typ: String,
- host: String,
+ hostPort: String,
receiverEndpoint: RpcEndpointRef,
senderAddress: RpcAddress
): Boolean = {
@@ -189,13 +241,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
if (isTrackerStopping || isTrackerStopped) {
false
+ } else if (!scheduleReceiver(streamId).contains(hostPort)) {
+ // Refuse it since it was scheduled to the wrong executor
+ false
} else {
- // "stopReceivers" won't happen at the same time because both "registerReceiver" and are
- // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If
- // "stopReceivers" is called later, it should be able to see this receiver.
- receiverInfo(streamId) = ReceiverInfo(
- streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
- listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
+ val name = s"${typ}-${streamId}"
+ val receiverTrackingInfo = ReceiverTrackingInfo(
+ streamId,
+ ReceiverState.ACTIVE,
+ scheduledExecutors = None,
+ runningExecutor = Some(hostPort),
+ name = Some(name),
+ endpoint = Some(receiverEndpoint))
+ receiverTrackingInfos.put(streamId, receiverTrackingInfo)
+ listenerBus.post(StreamingListenerReceiverStarted(receiverTrackingInfo.toReceiverInfo))
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
true
}
@@ -203,21 +262,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
/** Deregister a receiver */
private def deregisterReceiver(streamId: Int, message: String, error: String) {
- val newReceiverInfo = receiverInfo.get(streamId) match {
+ val lastErrorTime =
+ if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis()
+ val errorInfo = ReceiverErrorInfo(
+ lastErrorMessage = message, lastError = error, lastErrorTime = lastErrorTime)
+ val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match {
case Some(oldInfo) =>
- val lastErrorTime =
- if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis()
- oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message,
- lastError = error, lastErrorTime = lastErrorTime)
+ oldInfo.copy(state = ReceiverState.INACTIVE, errorInfo = Some(errorInfo))
case None =>
logWarning("No prior receiver info")
- val lastErrorTime =
- if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis()
- ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message,
- lastError = error, lastErrorTime = lastErrorTime)
+ ReceiverTrackingInfo(
+ streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo))
}
- receiverInfo -= streamId
- listenerBus.post(StreamingListenerReceiverStopped(newReceiverInfo))
+ receiverTrackingInfos -= streamId
+ listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo))
val messageWithError = if (error != null && !error.isEmpty) {
s"$message - $error"
} else {
@@ -228,9 +286,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
/** Update a receiver's maximum ingestion rate */
def sendRateUpdate(streamUID: Int, newRate: Long): Unit = {
- for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) {
- eP.send(UpdateRateLimit(newRate))
- }
+ endpoint.send(UpdateReceiverRateLimit(streamUID, newRate))
}
/** Add new blocks for the given stream */
@@ -240,16 +296,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
/** Report error sent by a receiver */
private def reportError(streamId: Int, message: String, error: String) {
- val newReceiverInfo = receiverInfo.get(streamId) match {
+ val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match {
case Some(oldInfo) =>
- oldInfo.copy(lastErrorMessage = message, lastError = error)
+ val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error,
+ lastErrorTime = oldInfo.errorInfo.map(_.lastErrorTime).getOrElse(-1L))
+ oldInfo.copy(errorInfo = Some(errorInfo))
case None =>
logWarning("No prior receiver info")
- ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message,
- lastError = error, lastErrorTime = ssc.scheduler.clock.getTimeMillis())
+ val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error,
+ lastErrorTime = ssc.scheduler.clock.getTimeMillis())
+ ReceiverTrackingInfo(
+ streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo))
}
- receiverInfo(streamId) = newReceiverInfo
- listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
+
+ receiverTrackingInfos(streamId) = newReceiverTrackingInfo
+ listenerBus.post(StreamingListenerReceiverError(newReceiverTrackingInfo.toReceiverInfo))
val messageWithError = if (error != null && !error.isEmpty) {
s"$message - $error"
} else {
@@ -258,171 +319,242 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
logWarning(s"Error reported by receiver for stream $streamId: $messageWithError")
}
+ private def scheduleReceiver(receiverId: Int): Seq[String] = {
+ val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None)
+ val scheduledExecutors = schedulingPolicy.rescheduleReceiver(
+ receiverId, preferredLocation, receiverTrackingInfos, getExecutors)
+ updateReceiverScheduledExecutors(receiverId, scheduledExecutors)
+ scheduledExecutors
+ }
+
+ private def updateReceiverScheduledExecutors(
+ receiverId: Int, scheduledExecutors: Seq[String]): Unit = {
+ val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match {
+ case Some(oldInfo) =>
+ oldInfo.copy(state = ReceiverState.SCHEDULED,
+ scheduledExecutors = Some(scheduledExecutors))
+ case None =>
+ ReceiverTrackingInfo(
+ receiverId,
+ ReceiverState.SCHEDULED,
+ Some(scheduledExecutors),
+ runningExecutor = None)
+ }
+ receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo)
+ }
+
/** Check if any blocks are left to be processed */
def hasUnallocatedBlocks: Boolean = {
receivedBlockTracker.hasUnallocatedReceivedBlocks
}
+ /**
+ * Get the list of executors, excluding the driver
+ */
+ private def getExecutors: Seq[String] = {
+ if (ssc.sc.isLocal) {
+ Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort)
+ } else {
+ ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) =>
+ blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location
+ }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq
+ }
+ }
+
+ /**
+ * Run a dummy Spark job to ensure that all slaves have registered. This prevents all the
+ * receivers from being scheduled on the same node.
+ *
+ * TODO: should poll the number of executors and wait for them according to
+ * "spark.scheduler.minRegisteredResourcesRatio" and
+ * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job.
+ */
+ private def runDummySparkJob(): Unit = {
+ if (!ssc.sparkContext.isLocal) {
+ ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
+ }
+ assert(getExecutors.nonEmpty)
+ }
+
+ /**
+ * Get the receivers from the ReceiverInputDStreams, distribute them to the
+ * worker nodes as a parallel collection, and run them.
+ */
+ private def launchReceivers(): Unit = {
+ val receivers = receiverInputStreams.map(nis => {
+ val rcvr = nis.getReceiver()
+ rcvr.setReceiverId(nis.id)
+ rcvr
+ })
+
+ runDummySparkJob()
+
+ logInfo("Starting " + receivers.length + " receivers")
+ endpoint.send(StartAllReceivers(receivers))
+ }
+
+ /** Check if the tracker has been marked as started */
+ private def isTrackerStarted: Boolean = trackerState == Started
+
+ /** Check if the tracker has been marked as stopping */
+ private def isTrackerStopping: Boolean = trackerState == Stopping
+
+ /** Check if the tracker has been marked as stopped */
+ private def isTrackerStopped: Boolean = trackerState == Stopped
+
/** RpcEndpoint to receive messages from the receivers. */
private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
+ // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged
+ private val submitJobThreadPool = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool"))
+
override def receive: PartialFunction[Any, Unit] = {
+ // Local messages
+ case StartAllReceivers(receivers) =>
+ val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors)
+ for (receiver <- receivers) {
+ val executors = scheduledExecutors(receiver.streamId)
+ updateReceiverScheduledExecutors(receiver.streamId, executors)
+ receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation
+ startReceiver(receiver, executors)
+ }
+ case RestartReceiver(receiver) =>
+ val scheduledExecutors = schedulingPolicy.rescheduleReceiver(
+ receiver.streamId,
+ receiver.preferredLocation,
+ receiverTrackingInfos,
+ getExecutors)
+ updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors)
+ startReceiver(receiver, scheduledExecutors)
+ case c: CleanupOldBlocks =>
+ receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c))
+ case UpdateReceiverRateLimit(streamUID, newRate) =>
+ for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) {
+ eP.send(UpdateRateLimit(newRate))
+ }
+ // Remote messages
case ReportError(streamId, message, error) =>
reportError(streamId, message, error)
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
+ // Remote messages
+ case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) =>
val successful =
- registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
+ registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address)
context.reply(successful)
case AddBlock(receivedBlockInfo) =>
context.reply(addBlock(receivedBlockInfo))
case DeregisterReceiver(streamId, message, error) =>
deregisterReceiver(streamId, message, error)
context.reply(true)
+ // Local messages
+ case AllReceiverIds =>
+ context.reply(receiverTrackingInfos.keys.toSeq)
case StopAllReceivers =>
assert(isTrackerStopping || isTrackerStopped)
stopReceivers()
context.reply(true)
}
- /** Send stop signal to the receivers. */
- private def stopReceivers() {
- // Signal the receivers to stop
- receiverInfo.values.flatMap { info => Option(info.endpoint)}
- .foreach { _.send(StopReceiver) }
- logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
- }
- }
-
- /** This thread class runs all the receivers on the cluster. */
- class ReceiverLauncher {
- @transient val env = ssc.env
- @volatile @transient var running = false
- @transient val thread = new Thread() {
- override def run() {
- try {
- SparkEnv.set(env)
- startReceivers()
- } catch {
- case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")
- }
- }
- }
-
- def start() {
- thread.start()
- }
-
/**
- * Get the list of executors excluding driver
- */
- private def getExecutors(ssc: StreamingContext): List[String] = {
- val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList
- val driver = ssc.sparkContext.getConf.get("spark.driver.host")
- executors.diff(List(driver))
- }
-
- /** Set host location(s) for each receiver so as to distribute them over
- * executors in a round-robin fashion taking into account preferredLocation if set
+ * Start a receiver along with its scheduled executors
*/
- private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]],
- executors: List[String]): Array[ArrayBuffer[String]] = {
- val locations = new Array[ArrayBuffer[String]](receivers.length)
- var i = 0
- for (i <- 0 until receivers.length) {
- locations(i) = new ArrayBuffer[String]()
- if (receivers(i).preferredLocation.isDefined) {
- locations(i) += receivers(i).preferredLocation.get
- }
+ private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = {
+ val receiverId = receiver.streamId
+ if (!isTrackerStarted) {
+ onReceiverJobFinish(receiverId)
+ return
}
- var count = 0
- for (i <- 0 until max(receivers.length, executors.length)) {
- if (!receivers(i % receivers.length).preferredLocation.isDefined) {
- locations(i % receivers.length) += executors(count)
- count += 1
- if (count == executors.length) {
- count = 0
- }
- }
- }
- locations
- }
-
- /**
- * Get the receivers from the ReceiverInputDStreams, distributes them to the
- * worker nodes as a parallel collection, and runs them.
- */
- private def startReceivers() {
- val receivers = receiverInputStreams.map(nis => {
- val rcvr = nis.getReceiver()
- rcvr.setReceiverId(nis.id)
- rcvr
- })
val checkpointDirOption = Option(ssc.checkpointDir)
val serializableHadoopConf =
new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration)
// Function to start the receiver on the worker node
- val startReceiver = (iterator: Iterator[Receiver[_]]) => {
- if (!iterator.hasNext) {
- throw new SparkException(
- "Could not start receiver as object not found.")
- }
- val receiver = iterator.next()
- val supervisor = new ReceiverSupervisorImpl(
- receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption)
- supervisor.start()
- supervisor.awaitTermination()
- }
-
- // Run the dummy Spark job to ensure that all slaves have registered.
- // This avoids all the receivers to be scheduled on the same node.
- if (!ssc.sparkContext.isLocal) {
- ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
- }
+ val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf)
- // Get the list of executors and schedule receivers
- val executors = getExecutors(ssc)
- val tempRDD =
- if (!executors.isEmpty) {
- val locations = scheduleReceivers(receivers, executors)
- val roundRobinReceivers = (0 until receivers.length).map(i =>
- (receivers(i), locations(i)))
- ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers)
+ // Create the RDD using the scheduledExecutors to run the receiver in a Spark job
+ val receiverRDD: RDD[Receiver[_]] =
+ if (scheduledExecutors.isEmpty) {
+ ssc.sc.makeRDD(Seq(receiver), 1)
} else {
- ssc.sc.makeRDD(receivers, receivers.size)
+ ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors))
}
+ receiverRDD.setName(s"Receiver $receiverId")
+ val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit](
+ receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ())
+ // We will keep restarting the receiver job until ReceiverTracker is stopped
+ future.onComplete {
+ case Success(_) =>
+ if (!isTrackerStarted) {
+ onReceiverJobFinish(receiverId)
+ } else {
+ logInfo(s"Restarting Receiver $receiverId")
+ self.send(RestartReceiver(receiver))
+ }
+ case Failure(e) =>
+ if (!isTrackerStarted) {
+ onReceiverJobFinish(receiverId)
+ } else {
+ logError("Receiver has been stopped. Try to restart it.", e)
+ logInfo(s"Restarting Receiver $receiverId")
+ self.send(RestartReceiver(receiver))
+ }
+ }(submitJobThreadPool)
+ logInfo(s"Receiver ${receiver.streamId} started")
+ }
- // Distribute the receivers and start them
- logInfo("Starting " + receivers.length + " receivers")
- running = true
- try {
- ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
- logInfo("All of the receivers have been terminated")
- } finally {
- running = false
- }
+ override def onStop(): Unit = {
+ submitJobThreadPool.shutdownNow()
}
/**
- * Wait until the Spark job that runs the receivers is terminated, or return when
- * `milliseconds` elapses
+ * Called when a receiver is terminated; we won't restart its Spark job.
*/
- def awaitTermination(milliseconds: Long): Unit = {
- thread.join(milliseconds)
+ private def onReceiverJobFinish(receiverId: Int): Unit = {
+ receiverJobExitLatch.countDown()
+ receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo =>
+ if (receiverTrackingInfo.state == ReceiverState.ACTIVE) {
+ logWarning(s"Receiver $receiverId exited but didn't deregister")
+ }
+ }
}
- }
- /** Check if tracker has been marked for starting */
- private def isTrackerStarted(): Boolean = trackerState == Started
+ /** Send stop signal to the receivers. */
+ private def stopReceivers() {
+ receiverTrackingInfos.values.flatMap(_.endpoint).foreach { _.send(StopReceiver) }
+ logInfo("Sent stop signal to all " + receiverTrackingInfos.size + " receivers")
+ }
+ }
- /** Check if tracker has been marked for stopping */
- private def isTrackerStopping(): Boolean = trackerState == Stopping
+}
- /** Check if tracker has been marked for stopped */
- private def isTrackerStopped(): Boolean = trackerState == Stopped
+/**
+ * Function to start the receiver on the worker node. Use a class instead of a closure to
+ * avoid serialization issues.
+ */
+private class StartReceiverFunc(
+ checkpointDirOption: Option[String],
+ serializableHadoopConf: SerializableConfiguration)
+ extends (Iterator[Receiver[_]] => Unit) with Serializable {
+
+ override def apply(iterator: Iterator[Receiver[_]]): Unit = {
+ if (!iterator.hasNext) {
+ throw new SparkException(
+ "Could not start receiver as object not found.")
+ }
+ if (TaskContext.get().attemptNumber() == 0) {
+ val receiver = iterator.next()
+ assert(!iterator.hasNext)
+ val supervisor = new ReceiverSupervisorImpl(
+ receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption)
+ supervisor.start()
+ supervisor.awaitTermination()
+ } else {
+ // The task was restarted by the TaskScheduler, but we want ReceiverTracker to
+ // reschedule it, so just exit.
+ }
+ }
}
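The future.onComplete handler in startReceiver above is what keeps a receiver alive: whenever its Spark job ends while the tracker is still in the Started state, a RestartReceiver message is sent back to the event loop, which reschedules and resubmits the job. A standalone sketch of this restart-until-stopped pattern, with a plain Future standing in for the submitted Spark job (all names here are illustrative, not part of the patch):

    import java.util.concurrent.Executors
    import scala.concurrent.{ExecutionContext, Future}
    import scala.util.{Failure, Success}

    object RestartLoopSketch {
      @volatile private var trackerStarted = true
      private implicit val ec: ExecutionContext =
        ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

      def startTask(id: Int): Unit = {
        // Stand-in for ssc.sparkContext.submitJob(receiverRDD, ...).
        Future { Thread.sleep(100) }.onComplete {
          case Success(_) if trackerStarted =>
            println(s"Restarting task $id")        // mirrors self.send(RestartReceiver(receiver))
            startTask(id)
          case Failure(e) if trackerStarted =>
            println(s"Task $id failed: ${e.getMessage}; restarting")
            startTask(id)
          case _ =>
            println(s"Task $id finished for good") // mirrors onReceiverJobFinish(receiverId)
        }
      }

      def main(args: Array[String]): Unit = {
        startTask(0)
        Thread.sleep(350)      // let the task restart a few times
        trackerStarted = false
        Thread.sleep(200)
        sys.exit(0)            // shut down the non-daemon thread pool
      }
    }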
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala
new file mode 100644
index 0000000000..043ff4d0ff
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.streaming.scheduler.ReceiverState._
+
+private[streaming] case class ReceiverErrorInfo(
+ lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)
+
+/**
+ * Class holding information about a receiver.
+ *
+ * @param receiverId the unique receiver id
+ * @param state the current Receiver state
+ * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy
+ * @param runningExecutor the running executor if the receiver is active
+ * @param name the receiver name
+ * @param endpoint the receiver endpoint. It can be used to send messages to the receiver
+ * @param errorInfo the receiver error information if it fails
+ */
+private[streaming] case class ReceiverTrackingInfo(
+ receiverId: Int,
+ state: ReceiverState,
+ scheduledExecutors: Option[Seq[String]],
+ runningExecutor: Option[String],
+ name: Option[String] = None,
+ endpoint: Option[RpcEndpointRef] = None,
+ errorInfo: Option[ReceiverErrorInfo] = None) {
+
+ def toReceiverInfo: ReceiverInfo = ReceiverInfo(
+ receiverId,
+ name.getOrElse(""),
+ state == ReceiverState.ACTIVE,
+ location = runningExecutor.getOrElse(""),
+ lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
+ lastError = errorInfo.map(_.lastError).getOrElse(""),
+ lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
+ )
+}
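A receiver's tracking record moves INACTIVE -> SCHEDULED -> ACTIVE as the tracker schedules it and then accepts its RegisterReceiver message. A REPL-style sketch of those transitions as updateReceiverScheduledExecutors and registerReceiver perform them above, assuming the org.apache.spark.streaming.scheduler package; it elides the name and endpoint fields that registerReceiver also fills in:

    // Receiver 0 is scheduled onto two candidate executors...
    val scheduled = ReceiverTrackingInfo(
      receiverId = 0,
      state = ReceiverState.SCHEDULED,
      scheduledExecutors = Some(Seq("host1:1000", "host2:1000")),
      runningExecutor = None)
    // ...then RegisterReceiver arrives from host1:1000, one of the candidates:
    val active = scheduled.copy(
      state = ReceiverState.ACTIVE,
      scheduledExecutors = None,
      runningExecutor = Some("host1:1000"))
    // Listeners only ever see the public view:
    assert(active.toReceiverInfo.active)
    assert(active.toReceiverInfo.location == "host1:1000")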
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
new file mode 100644
index 0000000000..93f920fdc7
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkFunSuite
+
+class ReceiverSchedulingPolicySuite extends SparkFunSuite {
+
+ val receiverSchedulingPolicy = new ReceiverSchedulingPolicy
+
+ test("rescheduleReceiver: empty executors") {
+ val scheduledExecutors =
+ receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty)
+ assert(scheduledExecutors === Seq.empty)
+ }
+
+ test("rescheduleReceiver: receiver preferredLocation") {
+ val receiverTrackingInfoMap = Map(
+ 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None))
+ val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver(
+ 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2"))
+ assert(scheduledExecutors.toSet === Set("host1", "host2"))
+ }
+
+ test("rescheduleReceiver: return all idle executors if more than 3 idle executors") {
+ val executors = Seq("host1", "host2", "host3", "host4", "host5")
+ // host1 has a running receiver; host2, host3, host4 and host5 are idle
+ val receiverTrackingInfoMap = Map(
+ 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")))
+ val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver(
+ 1, None, receiverTrackingInfoMap, executors)
+ assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5"))
+ }
+
+ test("rescheduleReceiver: return 3 best options if less than 3 idle executors") {
+ val executors = Seq("host1", "host2", "host3", "host4", "host5")
+ // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0
+ // host4 and host5 are idle
+ val receiverTrackingInfoMap = Map(
+ 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")),
+ 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None),
+ 2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None))
+ val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver(
+ 3, None, receiverTrackingInfoMap, executors)
+ assert(scheduledExecutors.toSet === Set("host2", "host4", "host5"))
+ }
+
+ test("scheduleReceivers: " +
+ "schedule receivers evenly when there are more receivers than executors") {
+ val receivers = (0 until 6).map(new DummyReceiver(_))
+ val executors = (10000 until 10003).map(port => s"localhost:${port}")
+ val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
+ val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+ // There should be 2 receivers running on each executor and each receiver has one executor
+ scheduledExecutors.foreach { case (receiverId, executors) =>
+ assert(executors.size == 1)
+ numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1
+ }
+ assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap)
+ }
+
+
+ test("scheduleReceivers: " +
+ "schedule receivers evenly when there are more executors than receivers") {
+ val receivers = (0 until 3).map(new DummyReceiver(_))
+ val executors = (10000 until 10006).map(port => s"localhost:${port}")
+ val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
+ val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+ // There should be 1 receiver running on each executor and each receiver has two executors
+ scheduledExecutors.foreach { case (receiverId, executors) =>
+ assert(executors.size == 2)
+ executors.foreach { l =>
+ numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1
+ }
+ }
+ assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap)
+ }
+
+ test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") {
+ val receivers = (0 until 3).map(new DummyReceiver(_)) ++
+ (3 until 6).map(new DummyReceiver(_, Some("localhost")))
+ val executors = (10000 until 10003).map(port => s"localhost:${port}") ++
+ (10003 until 10006).map(port => s"localhost2:${port}")
+ val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
+ val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+ // There should be 1 receiver running on each executor and each receiver has 1 executor
+ scheduledExecutors.foreach { case (receiverId, executors) =>
+ assert(executors.size == 1)
+ executors.foreach { l =>
+ numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1
+ }
+ }
+ assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap)
+ // Make sure we schedule the receivers to their preferredLocations
+ val executorsForReceiversWithPreferredLocation =
+ scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2)
+ // We can simply check the executor set because we know each receiver has exactly 1 executor
+ assert(executorsForReceiversWithPreferredLocation.toSet ===
+ (10000 until 10003).map(port => s"localhost:${port}").toSet)
+ }
+
+ test("scheduleReceivers: return empty if no receiver") {
+ assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty)
+ }
+
+ test("scheduleReceivers: return empty scheduled executors if no executors") {
+ val receivers = (0 until 3).map(new DummyReceiver(_))
+ val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty)
+ scheduledExecutors.foreach { case (receiverId, executors) =>
+ assert(executors.isEmpty)
+ }
+ }
+}
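For the "fewer than 3 idle executors" test above, the selection arithmetic works out as follows. A small sketch of that step of rescheduleReceiver, using the weights from the test (the object and value names are illustrative, not part of the patch):

    object BestOptionsSketch {
      def main(args: Array[String]): Unit = {
        val weights = Map("host1" -> 1.5, "host2" -> 0.5, "host3" -> 1.0) // from the test above
        val idle = Seq("host4", "host5")                                  // weight 0
        val sortedByWeight = weights.toSeq.sortBy(_._2).map(_._1)         // host2, host3, host1
        // Fewer than 3 idle executors: take the 3 best options overall.
        val chosen = (idle ++ (idle ++ sortedByWeight).take(3)).toSet
        println(chosen)                                                   // Set(host4, host5, host2)
      }
    }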
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index aadb723175..e2159bd4f2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -18,66 +18,18 @@
package org.apache.spark.streaming.scheduler
import org.scalatest.concurrent.Eventually._
-import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
-import org.apache.spark.streaming._
+
import org.apache.spark.SparkConf
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming._
import org.apache.spark.streaming.receiver._
-import org.apache.spark.util.Utils
-import org.apache.spark.streaming.dstream.InputDStream
-import scala.reflect.ClassTag
import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.storage.StorageLevel
/** Testsuite for receiver scheduling */
class ReceiverTrackerSuite extends TestSuiteBase {
val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test")
val ssc = new StreamingContext(sparkConf, Milliseconds(100))
- val tracker = new ReceiverTracker(ssc)
- val launcher = new tracker.ReceiverLauncher()
- val executors: List[String] = List("0", "1", "2", "3")
-
- test("receiver scheduling - all or none have preferred location") {
-
- def parse(s: String): Array[Array[String]] = {
- val outerSplit = s.split("\\|")
- val loc = new Array[Array[String]](outerSplit.length)
- var i = 0
- for (i <- 0 until outerSplit.length) {
- loc(i) = outerSplit(i).split("\\,")
- }
- loc
- }
-
- def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) {
- val receivers =
- if (preferredLocation) {
- Array.tabulate(numReceivers)(i => new DummyReceiver(host =
- Some(((i + 1) % executors.length).toString)))
- } else {
- Array.tabulate(numReceivers)(_ => new DummyReceiver)
- }
- val locations = launcher.scheduleReceivers(receivers, executors)
- val expectedLocations = parse(allocation)
- assert(locations.deep === expectedLocations.deep)
- }
-
- testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0")
- testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2")
- testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0")
- }
-
- test("receiver scheduling - some have preferred location") {
- val numReceivers = 4;
- val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")),
- new DummyReceiver, new DummyReceiver, new DummyReceiver)
- val locations = launcher.scheduleReceivers(receivers, executors)
- assert(locations(0)(0) === "1")
- assert(locations(1)(0) === "0")
- assert(locations(2)(0) === "1")
- assert(locations(0).length === 1)
- assert(locations(3).length === 1)
- }
test("Receiver tracker - propagates rate limit") {
object ReceiverStartedWaiter extends StreamingListener {
@@ -134,19 +86,19 @@ private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
* @note It's necessary to be a top-level object, or else serialization would create another
* one on the executor side and we won't be able to read its rate limit.
*/
-private object SingletonDummyReceiver extends DummyReceiver
+private object SingletonDummyReceiver extends DummyReceiver(0)
/**
* Dummy receiver implementation
*/
-private class DummyReceiver(host: Option[String] = None)
+private class DummyReceiver(receiverId: Int, host: Option[String] = None)
extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
- def onStart() {
- }
+ setReceiverId(receiverId)
- def onStop() {
- }
+ override def onStart(): Unit = {}
+
+ override def onStop(): Unit = {}
override def preferredLocation: Option[String] = host
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
index 40dc1fb601..0891309f95 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
@@ -119,20 +119,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
listener.numTotalReceivedRecords should be (600)
// onReceiverStarted
- val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost")
+ val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost")
listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted))
listener.receiverInfo(0) should be (Some(receiverInfoStarted))
listener.receiverInfo(1) should be (None)
// onReceiverError
- val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost")
+ val receiverInfoError = ReceiverInfo(1, "test", true, "localhost")
listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError))
listener.receiverInfo(0) should be (Some(receiverInfoStarted))
listener.receiverInfo(1) should be (Some(receiverInfoError))
listener.receiverInfo(2) should be (None)
// onReceiverStopped
- val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost")
+ val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost")
listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped))
listener.receiverInfo(0) should be (Some(receiverInfoStarted))
listener.receiverInfo(1) should be (Some(receiverInfoError))