aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/streaming/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/streaming/tests.py')
-rw-r--r--python/pyspark/streaming/tests.py266
1 files changed, 3 insertions, 263 deletions
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index f4bbb1b128..eb4696c55d 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -45,8 +45,6 @@ from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
-from pyspark.streaming.flume import FlumeUtils
-from pyspark.streaming.mqtt import MQTTUtils
from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
from pyspark.streaming.listener import StreamingListener
@@ -1262,207 +1260,6 @@ class KafkaStreamTests(PySparkStreamingTestCase):
self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream)
-class FlumeStreamTests(PySparkStreamingTestCase):
- timeout = 20 # seconds
- duration = 1
-
- def setUp(self):
- super(FlumeStreamTests, self).setUp()
- self._utils = self.ssc._jvm.org.apache.spark.streaming.flume.FlumeTestUtils()
-
- def tearDown(self):
- if self._utils is not None:
- self._utils.close()
- self._utils = None
-
- super(FlumeStreamTests, self).tearDown()
-
- def _startContext(self, n, compressed):
- # Start the StreamingContext and also collect the result
- dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(),
- enableDecompression=compressed)
- result = []
-
- def get_output(_, rdd):
- for event in rdd.collect():
- if len(result) < n:
- result.append(event)
- dstream.foreachRDD(get_output)
- self.ssc.start()
- return result
-
- def _validateResult(self, input, result):
- # Validate both the header and the body
- header = {"test": "header"}
- self.assertEqual(len(input), len(result))
- for i in range(0, len(input)):
- self.assertEqual(header, result[i][0])
- self.assertEqual(input[i], result[i][1])
-
- def _writeInput(self, input, compressed):
- # Try to write input to the receiver until success or timeout
- start_time = time.time()
- while True:
- try:
- self._utils.writeInput(input, compressed)
- break
- except:
- if time.time() - start_time < self.timeout:
- time.sleep(0.01)
- else:
- raise
-
- def test_flume_stream(self):
- input = [str(i) for i in range(1, 101)]
- result = self._startContext(len(input), False)
- self._writeInput(input, False)
- self.wait_for(result, len(input))
- self._validateResult(input, result)
-
- def test_compressed_flume_stream(self):
- input = [str(i) for i in range(1, 101)]
- result = self._startContext(len(input), True)
- self._writeInput(input, True)
- self.wait_for(result, len(input))
- self._validateResult(input, result)
-
-
-class FlumePollingStreamTests(PySparkStreamingTestCase):
- timeout = 20 # seconds
- duration = 1
- maxAttempts = 5
-
- def setUp(self):
- self._utils = self.sc._jvm.org.apache.spark.streaming.flume.PollingFlumeTestUtils()
-
- def tearDown(self):
- if self._utils is not None:
- self._utils.close()
- self._utils = None
-
- def _writeAndVerify(self, ports):
- # Set up the streaming context and input streams
- ssc = StreamingContext(self.sc, self.duration)
- try:
- addresses = [("localhost", port) for port in ports]
- dstream = FlumeUtils.createPollingStream(
- ssc,
- addresses,
- maxBatchSize=self._utils.eventsPerBatch(),
- parallelism=5)
- outputBuffer = []
-
- def get_output(_, rdd):
- for e in rdd.collect():
- outputBuffer.append(e)
-
- dstream.foreachRDD(get_output)
- ssc.start()
- self._utils.sendDatAndEnsureAllDataHasBeenReceived()
-
- self.wait_for(outputBuffer, self._utils.getTotalEvents())
- outputHeaders = [event[0] for event in outputBuffer]
- outputBodies = [event[1] for event in outputBuffer]
- self._utils.assertOutput(outputHeaders, outputBodies)
- finally:
- ssc.stop(False)
-
- def _testMultipleTimes(self, f):
- attempt = 0
- while True:
- try:
- f()
- break
- except:
- attempt += 1
- if attempt >= self.maxAttempts:
- raise
- else:
- import traceback
- traceback.print_exc()
-
- def _testFlumePolling(self):
- try:
- port = self._utils.startSingleSink()
- self._writeAndVerify([port])
- self._utils.assertChannelsAreEmpty()
- finally:
- self._utils.close()
-
- def _testFlumePollingMultipleHosts(self):
- try:
- port = self._utils.startSingleSink()
- self._writeAndVerify([port])
- self._utils.assertChannelsAreEmpty()
- finally:
- self._utils.close()
-
- def test_flume_polling(self):
- self._testMultipleTimes(self._testFlumePolling)
-
- def test_flume_polling_multiple_hosts(self):
- self._testMultipleTimes(self._testFlumePollingMultipleHosts)
-
-
-class MQTTStreamTests(PySparkStreamingTestCase):
- timeout = 20 # seconds
- duration = 1
-
- def setUp(self):
- super(MQTTStreamTests, self).setUp()
- self._MQTTTestUtils = self.ssc._jvm.org.apache.spark.streaming.mqtt.MQTTTestUtils()
- self._MQTTTestUtils.setup()
-
- def tearDown(self):
- if self._MQTTTestUtils is not None:
- self._MQTTTestUtils.teardown()
- self._MQTTTestUtils = None
-
- super(MQTTStreamTests, self).tearDown()
-
- def _randomTopic(self):
- return "topic-%d" % random.randint(0, 10000)
-
- def _startContext(self, topic):
- # Start the StreamingContext and also collect the result
- stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
- result = []
-
- def getOutput(_, rdd):
- for data in rdd.collect():
- result.append(data)
-
- stream.foreachRDD(getOutput)
- self.ssc.start()
- return result
-
- def test_mqtt_stream(self):
- """Test the Python MQTT stream API."""
- sendData = "MQTT demo for spark streaming"
- topic = self._randomTopic()
- result = self._startContext(topic)
-
- def retry():
- self._MQTTTestUtils.publishData(topic, sendData)
- # Because "publishData" sends duplicate messages, here we should use > 0
- self.assertTrue(len(result) > 0)
- self.assertEqual(sendData, result[0])
-
- # Retry it because we don't know when the receiver will start.
- self._retry_or_timeout(retry)
-
- def _retry_or_timeout(self, test_func):
- start_time = time.time()
- while True:
- try:
- test_func()
- break
- except:
- if time.time() - start_time > self.timeout:
- raise
- time.sleep(0.01)
-
-
class KinesisStreamTests(PySparkStreamingTestCase):
def test_kinesis_stream_api(self):
@@ -1551,57 +1348,6 @@ def search_kafka_assembly_jar():
return jars[0]
-def search_flume_assembly_jar():
- SPARK_HOME = os.environ["SPARK_HOME"]
- flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly")
- jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly")
- if not jars:
- raise Exception(
- ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) +
- "You need to build Spark with "
- "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or "
- "'build/mvn package' before running this test.")
- elif len(jars) > 1:
- raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please "
- "remove all but one") % (", ".join(jars)))
- else:
- return jars[0]
-
-
-def search_mqtt_assembly_jar():
- SPARK_HOME = os.environ["SPARK_HOME"]
- mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly")
- jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly")
- if not jars:
- raise Exception(
- ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) +
- "You need to build Spark with "
- "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or "
- "'build/mvn package' before running this test")
- elif len(jars) > 1:
- raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please "
- "remove all but one") % (", ".join(jars)))
- else:
- return jars[0]
-
-
-def search_mqtt_test_jar():
- SPARK_HOME = os.environ["SPARK_HOME"]
- mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt")
- jars = glob.glob(
- os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar"))
- if not jars:
- raise Exception(
- ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) +
- "You need to build Spark with "
- "'build/sbt assembly/assembly streaming-mqtt/test:assembly'")
- elif len(jars) > 1:
- raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please "
- "remove all but one") % (", ".join(jars)))
- else:
- return jars[0]
-
-
def search_kinesis_asl_assembly_jar():
SPARK_HOME = os.environ["SPARK_HOME"]
kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly")
@@ -1622,24 +1368,18 @@ are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1'
if __name__ == "__main__":
from pyspark.streaming.tests import *
kafka_assembly_jar = search_kafka_assembly_jar()
- flume_assembly_jar = search_flume_assembly_jar()
- mqtt_assembly_jar = search_mqtt_assembly_jar()
- mqtt_test_jar = search_mqtt_test_jar()
kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar()
if kinesis_asl_assembly_jar is None:
kinesis_jar_present = False
- jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar,
- mqtt_test_jar)
+ jars = kafka_assembly_jar
else:
kinesis_jar_present = True
- jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar,
- mqtt_test_jar, kinesis_asl_assembly_jar)
+ jars = "%s,%s" % (kafka_assembly_jar, kinesis_asl_assembly_jar)
os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests,
- KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests,
- StreamingListenerTests]
+ KafkaStreamTests, StreamingListenerTests]
if kinesis_jar_present is True:
testcases.append(KinesisStreamTests)