Diffstat (limited to 'python/pyspark/streaming/tests.py')
 python/pyspark/streaming/tests.py | 63 +++++++++++++++++++++++---------------
 1 file changed, 39 insertions(+), 24 deletions(-)
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 33f958a601..5fa1e5ef08 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -16,14 +16,23 @@
#
import os
+import sys
from itertools import chain
import time
import operator
-import unittest
import tempfile
import struct
from functools import reduce
+if sys.version_info[:2] <= (2, 6):
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
+else:
+ import unittest
+
from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.kafka import KafkaUtils
@@ -31,19 +40,25 @@ from pyspark.streaming.kafka import KafkaUtils
class PySparkStreamingTestCase(unittest.TestCase):
- timeout = 20 # seconds
- duration = 1
+ timeout = 4 # seconds
+ duration = .2
- def setUp(self):
- class_name = self.__class__.__name__
+ @classmethod
+ def setUpClass(cls):
+ class_name = cls.__name__
conf = SparkConf().set("spark.default.parallelism", 1)
- self.sc = SparkContext(appName=class_name, conf=conf)
- self.sc.setCheckpointDir("/tmp")
- # TODO: decrease duration to speed up tests
+ cls.sc = SparkContext(appName=class_name, conf=conf)
+ cls.sc.setCheckpointDir("/tmp")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sc.stop()
+
+ def setUp(self):
self.ssc = StreamingContext(self.sc, self.duration)
def tearDown(self):
- self.ssc.stop()
+ self.ssc.stop(False)
def wait_for(self, result, n):
start_time = time.time()
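The hunk above shares one SparkContext per test class via setUpClass/tearDownClass, while setUp/tearDown create and stop a fresh StreamingContext per test; stop(False) leaves the shared SparkContext running. The wait_for helper is truncated by the diff context; a minimal sketch of such a polling helper, assuming result is a list that output operations append batches to:

def wait_for(self, result, n):
    # Poll until `result` holds n batches or the class timeout elapses.
    start_time = time.time()
    while len(result) < n and time.time() - start_time < self.timeout:
        time.sleep(0.01)
    if len(result) < n:
        print("timeout after", self.timeout)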
@@ -363,13 +378,13 @@ class BasicOperationTests(PySparkStreamingTestCase):
class WindowFunctionTests(PySparkStreamingTestCase):
- timeout = 20
+ timeout = 5
def test_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.window(3, 1).count()
+ return dstream.window(.6, .2).count()
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
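With the batch duration cut to .2s, window(.6, .2) still spans three batches and slides one batch at a time, so the expected counts are unchanged; only wall-clock time shrinks. A quick sanity check of those numbers, assuming input batch sizes 1 through 5:

# Each windowed count sums up to 3 consecutive batch sizes; 7 windows
# are emitted as the window slides onto and off of the 5 batches.
sizes = [1, 2, 3, 4, 5]
counts = [sum(sizes[max(0, i - 2):i + 1]) for i in range(7)]
assert counts == [1, 3, 6, 9, 12, 9, 5]

The two countByWindow tests below follow the same arithmetic with three-batch (.6, .2) and five-batch (1, .2) windows.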
@@ -378,7 +393,7 @@ class WindowFunctionTests(PySparkStreamingTestCase):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.countByWindow(3, 1)
+ return dstream.countByWindow(.6, .2)
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
@@ -387,7 +402,7 @@ class WindowFunctionTests(PySparkStreamingTestCase):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByWindow(5, 1)
+ return dstream.countByWindow(1, .2)
expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
self._test_func(input, func, expected)
@@ -396,7 +411,7 @@ class WindowFunctionTests(PySparkStreamingTestCase):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByValueAndWindow(5, 1)
+ return dstream.countByValueAndWindow(1, .2)
expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
self._test_func(input, func, expected)
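Unlike countByWindow, countByValueAndWindow emits the number of distinct values in each window. Batch i holds range(i + 1), and every window includes the newest batch, so the distinct count is one more than the newest batch index; a sanity check under that reading:

# 10 windows slide over 6 batches; window i ends at batch min(i, 5),
# which contributes values 0..min(i, 5).
counts = [min(i, 5) + 1 for i in range(10)]
assert counts == [1, 2, 3, 4, 5, 6, 6, 6, 6, 6]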
@@ -405,7 +420,7 @@ class WindowFunctionTests(PySparkStreamingTestCase):
input = [[('a', i)] for i in range(5)]
def func(dstream):
- return dstream.groupByKeyAndWindow(3, 1).mapValues(list)
+ return dstream.groupByKeyAndWindow(.6, .2).mapValues(list)
expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
[('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
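groupByKeyAndWindow collects each key's values across the three-batch window, and mapValues(list) materializes the resulting iterable so it compares by value. The expected groups trace the same slide as the counting tests above:

# One 'a' value per batch; each window keeps up to 3 consecutive values.
vals = list(range(5))
groups = [vals[max(0, i - 2):i + 1] for i in range(7)]
assert groups == [[0], [0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4], [4]]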
@@ -436,8 +451,8 @@ class StreamingContextTests(PySparkStreamingTestCase):
def test_stop_multiple_times(self):
self._add_input_stream()
self.ssc.start()
- self.ssc.stop()
- self.ssc.stop()
+ self.ssc.stop(False)
+ self.ssc.stop(False)
def test_queue_stream(self):
input = [list(range(i + 1)) for i in range(3)]
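The doubled ssc.stop(False) calls above check that stopping is idempotent while keeping the class-level SparkContext alive; in PySpark the first positional argument of StreamingContext.stop is stopSparkContext. An illustrative sketch of the pattern (names assumed, not from the diff):

# Stop only the streaming machinery; the shared SparkContext stays usable.
ssc = StreamingContext(sc, 0.2)
ssc.start()
ssc.stop(False)  # streaming stopped, sc still alive
ssc.stop(False)  # second stop must be a harmless no-op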
@@ -495,10 +510,7 @@ class StreamingContextTests(PySparkStreamingTestCase):
self.assertEqual([2, 3, 1], self._take(dstream, 3))
-class CheckpointTests(PySparkStreamingTestCase):
-
- def setUp(self):
- pass
+class CheckpointTests(unittest.TestCase):
def test_get_or_create(self):
inputd = tempfile.mkdtemp()
@@ -518,12 +530,12 @@ class CheckpointTests(PySparkStreamingTestCase):
return ssc
cpd = tempfile.mkdtemp("test_streaming_cps")
- self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
def check_output(n):
while not os.listdir(outputd):
- time.sleep(0.1)
+ time.sleep(0.01)
time.sleep(1) # make sure mtime is larger than the previous one
with open(os.path.join(inputd, str(n)), 'w') as f:
f.writelines(["%d\n" % i for i in range(10)])
@@ -553,12 +565,15 @@ class CheckpointTests(PySparkStreamingTestCase):
ssc.stop(True, True)
time.sleep(1)
- self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
check_output(3)
+ ssc.stop(True, True)
class KafkaStreamTests(PySparkStreamingTestCase):
+ timeout = 20 # seconds
+ duration = 1
def setUp(self):
super(KafkaStreamTests, self).setUp()
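KafkaStreamTests restores the original timeout = 20 and duration = 1 because broker startup and receiver setup, not batch length, dominate its runtime. For context, a hedged sketch of the receiver-based stream these tests exercise, with illustrative parameters that are not taken from the diff:

# Assumed Spark 1.x API: KafkaUtils.createStream(ssc, zkQuorum, groupId,
# topics); the dict maps each topic to a number of receiver threads.
stream = KafkaUtils.createStream(
    self.ssc,
    "localhost:2181",           # ZooKeeper quorum (illustrative)
    "test-streaming-consumer",  # consumer group id (illustrative)
    {"topic1": 1},
)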