From 8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Sun, 20 Jan 2013 01:57:44 -0800
Subject: Added accumulators to PySpark

---
 python/pyspark/context.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index e486f206b0..1e2f845f9c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -2,6 +2,8 @@
 import os
 import atexit
 from tempfile import NamedTemporaryFile
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
@@ -22,6 +24,7 @@ class SparkContext(object):
     _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
     _takePartition = jvm.PythonRDD.takePartition
+    _next_accum_id = 0
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -52,6 +55,14 @@ class SparkContext(object):
         self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)
 
+        # Create a single Accumulator in Java that we'll send all our updates through;
+        # they will be passed back to us through a TCP server
+        self._accumulatorServer = accumulators._start_update_server()
+        (host, port) = self._accumulatorServer.server_address
+        self._javaAccumulator = self._jsc.accumulator(
+            self.jvm.java.util.ArrayList(),
+            self.jvm.PythonAccumulatorParam(host, port))
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
@@ -74,6 +85,8 @@ class SparkContext(object):
     def __del__(self):
         if self._jsc:
             self._jsc.stop()
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
 
     def stop(self):
         """
@@ -129,6 +142,31 @@ class SparkContext(object):
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
+    def accumulator(self, value, accum_param=None):
+        """
+        Create an C{Accumulator} with the given initial value, using a given
+        AccumulatorParam helper object to define how to add values of the data
+        type if provided. Default AccumulatorParams are used for integers and
+        floating-point numbers if you do not provide one. For other types, the
+        AccumulatorParam must implement two methods:
+        - C{zero(value)}: provide a "zero value" for the type, compatible in
+          dimensions with the provided C{value} (e.g., a zero vector).
+        - C{addInPlace(val1, val2)}: add two values of the accumulator's data
+          type, returning a new value; for efficiency, can also update C{val1}
+          in place and return it.
+        """
+        if accum_param == None:
+            if isinstance(value, int):
+                accum_param = accumulators.INT_ACCUMULATOR_PARAM
+            elif isinstance(value, float):
+                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
+            elif isinstance(value, complex):
+                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+            else:
+                raise Exception("No default accumulator param for type %s" % type(value))
+        SparkContext._next_accum_id += 1
+        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
     def addFile(self, path):
         """
         Add a file to be downloaded into the working directory of this Spark
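
Usage note (not part of the patch): a minimal driver-side sketch of the API
this commit adds, assuming a live local SparkContext and the era's Accumulator,
which merges updates via "+=" (addInPlace); the "global" declaration is needed
because "+=" rebinds the name inside the function. The app name and variable
names are illustrative.

    from pyspark.context import SparkContext

    sc = SparkContext("local", "accumulator-demo")

    # An int initial value selects INT_ACCUMULATOR_PARAM automatically.
    counter = sc.accumulator(0)

    def add_to_counter(x):
        # Tasks may only add to the accumulator; only the driver reads .value.
        global counter
        counter += x

    sc.parallelize([1, 2, 3, 4]).foreach(add_to_counter)
    print(counter.value)  # 10

Each worker-side update travels back to the driver through the single
Java-side accumulator and the TCP update server set up in __init__ above.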
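
The custom-type contract in the accumulator() docstring (zero and addInPlace)
is duck-typed: the patch calls only those two methods on accum_param, so no
particular base class is required. A hedged sketch for list-of-floats
"vectors", with hypothetical names, reusing the sc from the example above:

    class VectorAccumulatorParam(object):
        def zero(self, value):
            # A zero vector with the same dimensions as the initial value.
            return [0.0] * len(value)

        def addInPlace(self, val1, val2):
            # Update val1 in place and return it, as the docstring permits.
            for i in range(len(val1)):
                val1[i] += val2[i]
            return val1

    vec = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())

    def add_vector(x):
        global vec
        vec += [x, x, x]

    sc.parallelize([1.0, 2.0]).foreach(add_vector)
    print(vec.value)  # [4.0, 5.0, 6.0]: the initial vector plus [3.0, 3.0, 3.0]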