diff options
Diffstat (limited to 'external/mqtt')
5 files changed, 214 insertions, 103 deletions
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e57817..69b309876a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -78,5 +78,33 @@ <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + + <plugins> + <!-- Assemble a jar with test dependencies for Python tests --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-assembly-plugin</artifactId> + <executions> + <execution> + <id>test-jar-with-dependencies</id> + <phase>package</phase> + <goals> + <goal>single</goal> + </goals> + <configuration> + <!-- Make sure the file path is same as the sbt build --> + <finalName>spark-streaming-mqtt-test-${project.version}</finalName> + <outputDirectory>${project.build.directory}/scala-${scala.binary.version}/</outputDirectory> + <appendAssemblyId>false</appendAssemblyId> + <!-- Don't publish it since it's only for Python tests --> + <attach>false</attach> + <descriptors> + <descriptor>src/main/assembly/assembly.xml</descriptor> + </descriptors> + </configuration> + </execution> + </executions> + </plugin> + </plugins> </build> </project> diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 0000000000..ecab5b360e --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> +<assembly> + <id>test-jar-with-dependencies</id> + <formats> + <format>jar</format> + </formats> + <includeBaseDirectory>false</includeBaseDirectory> + + <fileSets> + <fileSet> + <directory>${project.build.directory}/scala-${scala.binary.version}/test-classes</directory> + <outputDirectory>/</outputDirectory> + </fileSet> + </fileSets> + + <dependencySets> + <dependencySet> + <useTransitiveDependencies>true</useTransitiveDependencies> + <scope>test</scope> + <unpack>true</unpack> + <excludes> + <exclude>org.apache.hadoop:*:jar</exclude> + <exclude>org.apache.zookeeper:*:jar</exclude> + <exclude>org.apache.avro:*:jar</exclude> + </excludes> + </dependencySet> + </dependencySets> + +</assembly> diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56b..38a1114863 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -74,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. + */ +private class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index c4bf5aa786..a6a9249db8 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,46 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { @@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. - */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 0000000000..1a371b7008 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} |