diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/streaming/dstream.py | 6 | ||||
-rw-r--r-- | python/pyspark/streaming/tests.py | 11 |
2 files changed, 14 insertions, 3 deletions
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 698336cfce..acec850f02 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -524,8 +524,8 @@ class DStream(object): `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -556,7 +556,7 @@ class DStream(object): if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0bcd1f1553..3403f6d20d 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -582,6 +582,17 @@ class WindowFunctionTests(PySparkStreamingTestCase): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): |