diff options
Diffstat (limited to 'python/pyspark/streaming/tests.py')
-rw-r--r-- | python/pyspark/streaming/tests.py | 63 |
1 files changed, 39 insertions, 24 deletions
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33f958a601..5fa1e5ef08 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -16,14 +16,23 @@ # import os +import sys from itertools import chain import time import operator -import unittest import tempfile import struct from functools import reduce +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import KafkaUtils @@ -31,19 +40,25 @@ from pyspark.streaming.kafka import KafkaUtils class PySparkStreamingTestCase(unittest.TestCase): - timeout = 20 # seconds - duration = 1 + timeout = 4 # seconds + duration = .2 - def setUp(self): - class_name = self.__class__.__name__ + @classmethod + def setUpClass(cls): + class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) - self.sc = SparkContext(appName=class_name, conf=conf) - self.sc.setCheckpointDir("/tmp") - # TODO: decrease duration to speed up tests + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir("/tmp") + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop() + self.ssc.stop(False) def wait_for(self, result, n): start_time = time.time() @@ -363,13 +378,13 @@ class BasicOperationTests(PySparkStreamingTestCase): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 20 + timeout = 5 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(3, 1).count() + return dstream.window(.6, .2).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -378,7 +393,7 @@ class WindowFunctionTests(PySparkStreamingTestCase): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(3, 1) + return dstream.countByWindow(.6, .2) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -387,7 +402,7 @@ class WindowFunctionTests(PySparkStreamingTestCase): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(5, 1) + return dstream.countByWindow(1, .2) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -396,7 +411,7 @@ class WindowFunctionTests(PySparkStreamingTestCase): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(5, 1) + return dstream.countByValueAndWindow(1, .2) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -405,7 +420,7 @@ class WindowFunctionTests(PySparkStreamingTestCase): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] @@ -436,8 +451,8 @@ class StreamingContextTests(PySparkStreamingTestCase): def test_stop_multiple_times(self): self._add_input_stream() self.ssc.start() - self.ssc.stop() - self.ssc.stop() + self.ssc.stop(False) + self.ssc.stop(False) def test_queue_stream(self): input = [list(range(i + 1)) for i in range(3)] @@ -495,10 +510,7 @@ class StreamingContextTests(PySparkStreamingTestCase): self.assertEqual([2, 3, 1], self._take(dstream, 3)) -class CheckpointTests(PySparkStreamingTestCase): - - def setUp(self): - pass +class CheckpointTests(unittest.TestCase): def test_get_or_create(self): inputd = tempfile.mkdtemp() @@ -518,12 +530,12 @@ class CheckpointTests(PySparkStreamingTestCase): return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() def check_output(n): while not os.listdir(outputd): - time.sleep(0.1) + time.sleep(0.01) time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) @@ -553,12 +565,15 @@ class CheckpointTests(PySparkStreamingTestCase): ssc.stop(True, True) time.sleep(1) - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() check_output(3) + ssc.stop(True, True) class KafkaStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 def setUp(self): super(KafkaStreamTests, self).setUp() |