diff options
Diffstat (limited to 'python/pyspark/streaming/tests.py')
-rw-r--r-- | python/pyspark/streaming/tests.py | 266 |
1 files changed, 3 insertions, 263 deletions
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index f4bbb1b128..eb4696c55d 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -45,8 +45,6 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition -from pyspark.streaming.flume import FlumeUtils -from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream from pyspark.streaming.listener import StreamingListener @@ -1262,207 +1260,6 @@ class KafkaStreamTests(PySparkStreamingTestCase): self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream) -class FlumeStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - - def setUp(self): - super(FlumeStreamTests, self).setUp() - self._utils = self.ssc._jvm.org.apache.spark.streaming.flume.FlumeTestUtils() - - def tearDown(self): - if self._utils is not None: - self._utils.close() - self._utils = None - - super(FlumeStreamTests, self).tearDown() - - def _startContext(self, n, compressed): - # Start the StreamingContext and also collect the result - dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), - enableDecompression=compressed) - result = [] - - def get_output(_, rdd): - for event in rdd.collect(): - if len(result) < n: - result.append(event) - dstream.foreachRDD(get_output) - self.ssc.start() - return result - - def _validateResult(self, input, result): - # Validate both the header and the body - header = {"test": "header"} - self.assertEqual(len(input), len(result)) - for i in range(0, len(input)): - self.assertEqual(header, result[i][0]) - self.assertEqual(input[i], result[i][1]) - - def _writeInput(self, input, compressed): - # Try to write input to the receiver until success or timeout - start_time = time.time() - while True: - try: - self._utils.writeInput(input, compressed) - break - except: - if time.time() - start_time < self.timeout: - time.sleep(0.01) - else: - raise - - def test_flume_stream(self): - input = [str(i) for i in range(1, 101)] - result = self._startContext(len(input), False) - self._writeInput(input, False) - self.wait_for(result, len(input)) - self._validateResult(input, result) - - def test_compressed_flume_stream(self): - input = [str(i) for i in range(1, 101)] - result = self._startContext(len(input), True) - self._writeInput(input, True) - self.wait_for(result, len(input)) - self._validateResult(input, result) - - -class FlumePollingStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - maxAttempts = 5 - - def setUp(self): - self._utils = self.sc._jvm.org.apache.spark.streaming.flume.PollingFlumeTestUtils() - - def tearDown(self): - if self._utils is not None: - self._utils.close() - self._utils = None - - def _writeAndVerify(self, ports): - # Set up the streaming context and input streams - ssc = StreamingContext(self.sc, self.duration) - try: - addresses = [("localhost", port) for port in ports] - dstream = FlumeUtils.createPollingStream( - ssc, - addresses, - maxBatchSize=self._utils.eventsPerBatch(), - parallelism=5) - outputBuffer = [] - - def get_output(_, rdd): - for e in rdd.collect(): - outputBuffer.append(e) - - dstream.foreachRDD(get_output) - ssc.start() - self._utils.sendDatAndEnsureAllDataHasBeenReceived() - - self.wait_for(outputBuffer, self._utils.getTotalEvents()) - outputHeaders = [event[0] for event in outputBuffer] - outputBodies = [event[1] for event in outputBuffer] - self._utils.assertOutput(outputHeaders, outputBodies) - finally: - ssc.stop(False) - - def _testMultipleTimes(self, f): - attempt = 0 - while True: - try: - f() - break - except: - attempt += 1 - if attempt >= self.maxAttempts: - raise - else: - import traceback - traceback.print_exc() - - def _testFlumePolling(self): - try: - port = self._utils.startSingleSink() - self._writeAndVerify([port]) - self._utils.assertChannelsAreEmpty() - finally: - self._utils.close() - - def _testFlumePollingMultipleHosts(self): - try: - port = self._utils.startSingleSink() - self._writeAndVerify([port]) - self._utils.assertChannelsAreEmpty() - finally: - self._utils.close() - - def test_flume_polling(self): - self._testMultipleTimes(self._testFlumePolling) - - def test_flume_polling_multiple_hosts(self): - self._testMultipleTimes(self._testFlumePollingMultipleHosts) - - -class MQTTStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - - def setUp(self): - super(MQTTStreamTests, self).setUp() - self._MQTTTestUtils = self.ssc._jvm.org.apache.spark.streaming.mqtt.MQTTTestUtils() - self._MQTTTestUtils.setup() - - def tearDown(self): - if self._MQTTTestUtils is not None: - self._MQTTTestUtils.teardown() - self._MQTTTestUtils = None - - super(MQTTStreamTests, self).tearDown() - - def _randomTopic(self): - return "topic-%d" % random.randint(0, 10000) - - def _startContext(self, topic): - # Start the StreamingContext and also collect the result - stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) - result = [] - - def getOutput(_, rdd): - for data in rdd.collect(): - result.append(data) - - stream.foreachRDD(getOutput) - self.ssc.start() - return result - - def test_mqtt_stream(self): - """Test the Python MQTT stream API.""" - sendData = "MQTT demo for spark streaming" - topic = self._randomTopic() - result = self._startContext(topic) - - def retry(): - self._MQTTTestUtils.publishData(topic, sendData) - # Because "publishData" sends duplicate messages, here we should use > 0 - self.assertTrue(len(result) > 0) - self.assertEqual(sendData, result[0]) - - # Retry it because we don't know when the receiver will start. - self._retry_or_timeout(retry) - - def _retry_or_timeout(self, test_func): - start_time = time.time() - while True: - try: - test_func() - break - except: - if time.time() - start_time > self.timeout: - raise - time.sleep(0.01) - - class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -1551,57 +1348,6 @@ def search_kafka_assembly_jar(): return jars[0] -def search_flume_assembly_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") - jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly") - if not jars: - raise Exception( - ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -def search_mqtt_assembly_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") - jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly") - if not jars: - raise Exception( - ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " - "'build/mvn package' before running this test") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -def search_mqtt_test_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") - jars = glob.glob( - os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) - if not jars: - raise Exception( - ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - def search_kinesis_asl_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") @@ -1622,24 +1368,18 @@ are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' if __name__ == "__main__": from pyspark.streaming.tests import * kafka_assembly_jar = search_kafka_assembly_jar() - flume_assembly_jar = search_flume_assembly_jar() - mqtt_assembly_jar = search_mqtt_assembly_jar() - mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() if kinesis_asl_assembly_jar is None: kinesis_jar_present = False - jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, - mqtt_test_jar) + jars = kafka_assembly_jar else: kinesis_jar_present = True - jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, - mqtt_test_jar, kinesis_asl_assembly_jar) + jars = "%s,%s" % (kafka_assembly_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests, - StreamingListenerTests] + KafkaStreamTests, StreamingListenerTests] if kinesis_jar_present is True: testcases.append(KinesisStreamTests) |