 external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala | 26
 external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala  | 35
 2 files changed, 46 insertions(+), 15 deletions(-)
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
index 77661f71ad..1ef91dd492 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
@@ -55,14 +55,14 @@ class MQTTInputDStream(
brokerUrl: String,
topic: String,
storageLevel: StorageLevel
- ) extends ReceiverInputDStream[String](ssc_) with Logging {
-
+ ) extends ReceiverInputDStream[String](ssc_) {
+
def getReceiver(): Receiver[String] = {
new MQTTReceiver(brokerUrl, topic, storageLevel)
}
}
-private[streaming]
+private[streaming]
class MQTTReceiver(
brokerUrl: String,
topic: String,
@@ -72,21 +72,15 @@ class MQTTReceiver(
def onStop() {
}
-
+
def onStart() {
- // Set up persistence for messages
+ // Set up persistence for messages
val persistence = new MemoryPersistence()
// Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance
val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
- // Connect to MqttBroker
- client.connect()
-
- // Subscribe to Mqtt topic
- client.subscribe(topic)
-
// Callback automatically triggers as and when new message arrives on specified topic
val callback: MqttCallback = new MqttCallback() {
@@ -103,7 +97,15 @@ class MQTTReceiver(
}
}
- // Set up callback for MqttClient
+ // Set up callback for MqttClient. This needs to happen before
+ // connecting or subscribing, otherwise messages may be lost
client.setCallback(callback)
+
+ // Connect to MqttBroker
+ client.connect()
+
+ // Subscribe to Mqtt topic
+ client.subscribe(topic)
+
}
}
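
The receiver-side change is an ordering constraint in the Paho MQTT client: the callback must be registered before `connect()` and `subscribe()`, otherwise messages delivered right after subscribing can arrive while no handler is installed and be lost. A minimal standalone sketch of the same pattern, assuming a local broker at `tcp://localhost:1883` and a `demo` topic (both placeholder values, not part of the patch):

```scala
import org.eclipse.paho.client.mqttv3.{IMqttDeliveryToken, MqttCallback, MqttClient, MqttMessage}
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence

object CallbackBeforeConnect {
  def main(args: Array[String]): Unit = {
    // Placeholder broker URL and topic, for illustration only.
    val brokerUrl = "tcp://localhost:1883"
    val topic     = "demo"

    val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), new MemoryPersistence())

    // Register the callback first, so messages delivered immediately after
    // subscribing cannot slip through before a handler is in place.
    client.setCallback(new MqttCallback {
      override def connectionLost(cause: Throwable): Unit = ()
      override def messageArrived(topic: String, message: MqttMessage): Unit =
        println(new String(message.getPayload, "utf-8"))
      override def deliveryComplete(token: IMqttDeliveryToken): Unit = ()
    })

    // Only then connect and subscribe, matching the order introduced by the patch.
    client.connect()
    client.subscribe(topic)
  }
}
```
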
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index fe53a29cba..e84adc088a 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.streaming.mqtt
import java.net.{URI, ServerSocket}
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -32,6 +34,8 @@ import org.scalatest.concurrent.Eventually
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.streaming.scheduler.StreamingListener
+import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.SparkConf
import org.apache.spark.util.Utils
@@ -67,7 +71,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
val sendMessage = "MQTT demo for spark streaming"
val receiveStream: ReceiverInputDStream[String] =
MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
- var receiveMessage: List[String] = List()
+ @volatile var receiveMessage: List[String] = List()
receiveStream.foreachRDD { rdd =>
if (rdd.collect.length > 0) {
receiveMessage = receiveMessage ::: List(rdd.first)
@@ -75,6 +79,11 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
}
}
ssc.start()
+
+ // wait for the receiver to start before publishing data, or we risk failing
+ // the test nondeterministically. See SPARK-4631
+ waitForReceiverToStart()
+
publishData(sendMessage)
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
assert(sendMessage.equals(receiveMessage(0)))
@@ -121,8 +130,14 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
val message: MqttMessage = new MqttMessage(data.getBytes("utf-8"))
message.setQos(1)
message.setRetained(true)
- for (i <- 0 to 100) {
- msgTopic.publish(message)
+
+ for (i <- 0 to 10) {
+ try {
+ msgTopic.publish(message)
+ } catch {
+ case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
+ Thread.sleep(50) // wait for Spark streaming to consume something from the message queue
+ }
}
}
} finally {
@@ -131,4 +146,18 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
client = null
}
}
+
+ /**
+ * Block until at least one receiver has started or timeout occurs.
+ */
+ private def waitForReceiverToStart() = {
+ val latch = new CountDownLatch(1)
+ ssc.addStreamingListener(new StreamingListener {
+ override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
+ latch.countDown()
+ }
+ })
+
+ assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
+ }
}
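
The test-side change follows a pattern that is useful beyond this suite: block on a `CountDownLatch` that is released by a `StreamingListener` when a receiver reports it has started, and only then publish test data. A reusable sketch of that helper, adapted from `waitForReceiverToStart()` above; the object and method names here are illustrative, not part of the patch:

```scala
import java.util.concurrent.{CountDownLatch, TimeUnit}

import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted}

// Hypothetical helper, adapted from waitForReceiverToStart() in the patch above.
object ReceiverReadiness {
  /** Returns true once any receiver has started, false if the timeout elapses first. */
  def awaitReceiverStart(ssc: StreamingContext, timeoutSeconds: Long = 10L): Boolean = {
    val latch = new CountDownLatch(1)
    ssc.addStreamingListener(new StreamingListener {
      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit =
        latch.countDown()
    })
    latch.await(timeoutSeconds, TimeUnit.SECONDS)
  }
}
```

In the suite above the helper is invoked right after `ssc.start()` and before `publishData()`, which is what removes the nondeterministic failure tracked as SPARK-4631.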