diff options
author | Ewen Cheslack-Postava <me@ewencp.org> | 2013-10-19 19:55:39 -0700 |
---|---|---|
committer | Ewen Cheslack-Postava <me@ewencp.org> | 2013-10-19 19:55:39 -0700 |
commit | 7eaa56de7f0253869fa85d4366f1048386af477e (patch) | |
tree | 0cc31f0653e2bac5282f568dc6eab9632c347b3b /python | |
parent | 6511bbe2adeb5e361fb3c31bbda245eeb890647a (diff) | |
download | spark-7eaa56de7f0253869fa85d4366f1048386af477e.tar.gz spark-7eaa56de7f0253869fa85d4366f1048386af477e.tar.bz2 spark-7eaa56de7f0253869fa85d4366f1048386af477e.zip |
Add an add() method to pyspark accumulators.
Add a regular method for adding a term to accumulators in
pyspark. Currently if you have a non-global accumulator, adding to it
is awkward. The += operator can't be used for non-global accumulators
captured via closure because it's involves an assignment. The only way
to do it is using __iadd__ directly.
Adding this method lets you write code like this:
def main():
sc = SparkContext()
accum = sc.accumulator(0)
rdd = sc.parallelize([1,2,3])
def f(x):
accum.add(x)
rdd.foreach(f)
print accum.value
where using accum += x instead would have caused UnboundLocalError
exceptions in workers. Currently it would have to be written as
accum.__iadd__(x).
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/accumulators.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index d367f91967..da3d96689a 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -42,6 +42,13 @@ >>> a.value 13 +>>> b = sc.accumulator(0) +>>> def g(x): +... b.add(x) +>>> rdd.foreach(g) +>>> b.value +6 + >>> from pyspark.accumulators import AccumulatorParam >>> class VectorAccumulatorParam(AccumulatorParam): ... def zero(self, value): @@ -139,9 +146,13 @@ class Accumulator(object): raise Exception("Accumulator.value cannot be accessed inside tasks") self._value = value + def add(self, term): + """Adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + def __iadd__(self, term): """The += operator; adds a term to this accumulator's value""" - self._value = self.accum_param.addInPlace(self._value, term) + self.add(term) return self def __str__(self): |