aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- python/pyspark/streaming/dstream.py 6
-rw-r--r-- python/pyspark/streaming/tests.py 11
2 files changed, 14 insertions, 3 deletions
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 698336cfce..acec850f02 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -524,8 +524,8 @@ class DStream(object):
`invFunc` can be None, then it will reduce all the RDDs in window, could be slower
than having `invFunc`.
- @param reduceFunc: associative reduce function
- @param invReduceFunc: inverse function of `reduceFunc`
+ @param func: associative reduce function
+ @param invFunc: inverse function of `func`
@param windowDuration: width of the window; must be a multiple of this DStream's
batching interval
@param slideDuration: sliding interval of the window (i.e., the interval after which
@@ -556,7 +556,7 @@ class DStream(object):
if kv[1] is not None else kv[0])
jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
- if invReduceFunc:
+ if invFunc:
jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer)
else:
jinvReduceFunc = None
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 0bcd1f1553..3403f6d20d 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -582,6 +582,17 @@ class WindowFunctionTests(PySparkStreamingTestCase):
self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
+ def test_reduce_by_key_and_window_with_none_invFunc(self):
+ input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+ def func(dstream):
+ return dstream.map(lambda x: (x, 1))\
+ .reduceByKeyAndWindow(operator.add, None, 5, 1)\
+ .filter(lambda kv: kv[1] > 0).count()
+
+ expected = [[2], [4], [6], [6], [6], [6]]
+ self._test_func(input, func, expected)
+
class StreamingContextTests(PySparkStreamingTestCase):