aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/accumulators.py
diff options
context:
space:
mode:
authorEwen Cheslack-Postava <me@ewencp.org>2013-10-19 19:55:39 -0700
committerEwen Cheslack-Postava <me@ewencp.org>2013-10-19 19:55:39 -0700
commit7eaa56de7f0253869fa85d4366f1048386af477e (patch)
tree0cc31f0653e2bac5282f568dc6eab9632c347b3b /python/pyspark/accumulators.py
parent6511bbe2adeb5e361fb3c31bbda245eeb890647a (diff)
downloadspark-7eaa56de7f0253869fa85d4366f1048386af477e.tar.gz
spark-7eaa56de7f0253869fa85d4366f1048386af477e.tar.bz2
spark-7eaa56de7f0253869fa85d4366f1048386af477e.zip
Add an add() method to pyspark accumulators.
Add a regular method for adding a term to accumulators in pyspark. Currently, if you have a non-global accumulator, adding to it is awkward. The += operator can't be used for non-global accumulators captured via closure because it involves an assignment. The only way to do it is to call __iadd__ directly. Adding this method lets you write code like this:

    def main():
        sc = SparkContext()
        accum = sc.accumulator(0)
        rdd = sc.parallelize([1,2,3])
        def f(x):
            accum.add(x)
        rdd.foreach(f)
        print accum.value

where using accum += x instead would have caused UnboundLocalError exceptions in workers. Currently it would have to be written as accum.__iadd__(x).
Diffstat (limited to 'python/pyspark/accumulators.py')
-rw-r--r--python/pyspark/accumulators.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index d367f91967..da3d96689a 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -42,6 +42,13 @@
>>> a.value
13
+>>> b = sc.accumulator(0)
+>>> def g(x):
+... b.add(x)
+>>> rdd.foreach(g)
+>>> b.value
+6
+
>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
... def zero(self, value):
@@ -139,9 +146,13 @@ class Accumulator(object):
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
+ def add(self, term):
+ """Adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
- self._value = self.accum_param.addInPlace(self._value, term)
+ self.add(term)
return self
def __str__(self):