"""
Accumulator shared variables for PySpark and the server that receives their updates.

>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
1
>>> a.value = 2
>>> a.value
2
>>> a += 5
>>> a.value
7

>>> sc.accumulator(1.0).value
1.0

>>> sc.accumulator(1j).value
1j

>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
...     global a
...     a += x
>>> rdd.foreach(f)
>>> a.value
13

>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
...     def zero(self, value):
...         return [0.0] * len(value)
...     def addInPlace(self, val1, val2):
...         for i in xrange(len(val1)):
...             val1[i] += val2[i]
...         return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
...     global va
...     va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]

>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> def h(x):
...     global a
...     a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Exception:...
"""

import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, read_with_length, load_pickle


# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}


def _deserialize_accumulator(aid, zero_value, accum_param):
    """
    Rebuild an L{Accumulator} from the arguments produced by C{Accumulator.__reduce__}.

    The new accumulator starts from C{zero_value}, is flagged as deserialized (which makes
    its C{value} property raise) and is stored in the accumulator registry under C{aid}.
    """
    # NOTE(review): the registry is imported from the package rather than taken from this
    # module's globals -- presumably so the same dict is used even when this file is loaded
    # under a different module name (e.g. as __main__ for the doctests); confirm.
    from pyspark.accumulators import _accumulatorRegistry
    accum = Accumulator(aid, zero_value, accum_param)
    accum._deserialized = True
    _accumulatorRegistry[aid] = accum
    return accum


class Accumulator(object):
    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
    operator, but only the driver program is allowed to access its value, using C{value}.
    Updates from the workers get propagated automatically to the driver program.

    While C{SparkContext} supports accumulators for primitive data types like C{int} and
    C{float}, users can also define accumulators for custom types by providing a custom
    L{AccumulatorParam} object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        # Set to True by _deserialize_accumulator; guards the `value` property.
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        # The current value is not pickled: the copy starts from the param's zero value.
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Sets the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)
        return self

    def __str__(self):
        """Return the string form of the current value (no deserialization check)."""
        return str(self._value)

    def __repr__(self):
        """Return a debugging representation showing the accumulator ID and current value."""
        # The format string previously had no conversion specifiers, so applying it to the
        # (aid, value) tuple raised TypeError on every call.
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)


class AccumulatorParam(object):
    """
    Helper object that defines how to accumulate values of a given type.
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update C{value1} in place and return it.
        """
        raise NotImplementedError


class AddingAccumulatorParam(AccumulatorParam):
    """
    An AccumulatorParam that uses the + operators to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        """Remember C{zero_value}, which C{zero} returns regardless of its argument."""
        self.zero_value = zero_value

    def zero(self, value):
        """Return the zero value given at construction; C{value} is ignored."""
        return self.zero_value

    def addInPlace(self, value1, value2):
        """Return C{value1 + value2}, computed with the C{+=} operator."""
        # For mutable types such as lists, += modifies value1 itself.
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
    """Request handler that applies accumulator updates received over a connection."""

    def handle(self):
        """
        Read an update count, then that many C{(aid, update)} pairs, add each update to the
        registered accumulator with that ID, and finally write one acknowledgement byte.

        An ID that is not in the registry raises C{KeyError}.
        """
        from pyspark.accumulators import _accumulatorRegistry
        num_updates = read_int(self.rfile)
        for _ in range(num_updates):
            # NOTE(review): load_pickle presumably unpickles bytes read from the socket.
            # Unpickling can execute arbitrary code; the server created in
            # _start_update_server listens on localhost only, but any local process can
            # connect -- confirm that is acceptable.
            (aid, update) = load_pickle(read_with_length(self.rfile))
            _accumulatorRegistry[aid] += update
        # Write a byte in acknowledgement
        self.wfile.write(struct.pack("!b", 1))


def _start_update_server():
    """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
    # Port 0 asks the OS for any free port; callers can read it from server.server_address.
    server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
    thread = threading.Thread(target=server.serve_forever)
    # Daemon thread, so a running server does not keep the interpreter alive at exit.
    thread.daemon = True
    thread.start()
    return server