about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala  3
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala  4
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala  24
3 files changed, 29 insertions, 2 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
index 10b5a7f57a..d2b0be7f4a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
@@ -21,6 +21,7 @@ import scala.collection.Map
import scala.collection.mutable
import org.apache.spark.streaming.receiver.Receiver
+import org.apache.spark.util.Utils
/**
* A class that tries to schedule receivers in an evenly distributed manner. There are two phases for
@@ -79,7 +80,7 @@ private[streaming] class ReceiverSchedulingPolicy {
return receivers.map(_.streamId -> Seq.empty).toMap
}
- val hostToExecutors = executors.groupBy(_.split(":")(0))
+ val hostToExecutors = executors.groupBy(executor => Utils.parseHostPort(executor)._1)
val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String])
val numReceiversOnExecutor = mutable.HashMap[String, Int]()
// Set the initial value to 0
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index d053e9e849..2ce80d618b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -551,7 +551,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
if (scheduledExecutors.isEmpty) {
ssc.sc.makeRDD(Seq(receiver), 1)
} else {
- ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors))
+ val preferredLocations =
+ scheduledExecutors.map(hostPort => Utils.parseHostPort(hostPort)._1).distinct
+ ssc.sc.makeRDD(Seq(receiver -> preferredLocations))
}
receiverRDD.setName(s"Receiver $receiverId")
ssc.sparkContext.setJobDescription(s"Streaming job running receiver $receiverId")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index 45138b748e..fda86aef45 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLocality}
+import org.apache.spark.scheduler.TaskLocality.TaskLocality
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
import org.apache.spark.streaming.dstream.ReceiverInputDStream
@@ -80,6 +82,28 @@ class ReceiverTrackerSuite extends TestSuiteBase {
}
}
}
+
+ test("SPARK-11063: TaskSetManager should use Receiver RDD's preferredLocations") {
+ // Use ManualClock to prevent from starting batches so that we can make sure the only task is
+ // for starting the Receiver
+ val _conf = conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
+ withStreamingContext(new StreamingContext(_conf, Milliseconds(100))) { ssc =>
+ @volatile var receiverTaskLocality: TaskLocality = null
+ ssc.sparkContext.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+ receiverTaskLocality = taskStart.taskInfo.taskLocality
+ }
+ })
+ val input = ssc.receiverStream(new TestReceiver)
+ val output = new TestOutputStream(input)
+ output.register()
+ ssc.start()
+ eventually(timeout(10 seconds), interval(10 millis)) {
+ // If preferredLocations is set correctly, receiverTaskLocality should be NODE_LOCAL
+ assert(receiverTaskLocality === TaskLocality.NODE_LOCAL)
+ }
+ }
+ }
}
/** An input DStream with for testing rate controlling */