diff options
Diffstat (limited to 'python/pyspark/accumulators.py')
-rw-r--r-- | python/pyspark/accumulators.py | 19 |
1 file changed, 16 insertions, 3 deletions
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index d367f91967..2204e9c9ca 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -42,6 +42,13 @@ >>> a.value 13 +>>> b = sc.accumulator(0) +>>> def g(x): +... b.add(x) +>>> rdd.foreach(g) +>>> b.value +6 + >>> from pyspark.accumulators import AccumulatorParam >>> class VectorAccumulatorParam(AccumulatorParam): ... def zero(self, value): @@ -83,8 +90,10 @@ import struct import SocketServer import threading from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import read_int, read_with_length, load_pickle +from pyspark.serializers import read_int, PickleSerializer + +pickleSer = PickleSerializer() # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. @@ -139,9 +148,13 @@ class Accumulator(object): raise Exception("Accumulator.value cannot be accessed inside tasks") self._value = value + def add(self, term): + """Adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + def __iadd__(self, term): """The += operator; adds a term to this accumulator's value""" - self._value = self.accum_param.addInPlace(self._value, term) + self.add(term) return self def __str__(self): @@ -200,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): from pyspark.accumulators import _accumulatorRegistry num_updates = read_int(self.rfile) for _ in range(num_updates): - (aid, update) = load_pickle(read_with_length(self.rfile)) + (aid, update) = pickleSer._read_with_length(self.rfile) _accumulatorRegistry[aid] += update # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) |