author     Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-20 01:57:44 -0800
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-20 01:57:44 -0800
commit     8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa (patch)
tree       a4b5d1891501e99f44b4d797bee3a10504e0b2fd /python/pyspark/context.py
parent     54c0f9f185576e9b844fa8f81ca410f188daa51c (diff)
Added accumulators to PySpark
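
The diff below adds a SparkContext.accumulator() method backed by a single Java-side accumulator. As a rough usage sketch (illustrative only; it assumes the accompanying pyspark.accumulators module lets workers add with += and lets the driver read .value, and the master/job names here are made up):

    from pyspark.context import SparkContext

    sc = SparkContext("local", "accumulator demo")
    counter = sc.accumulator(0)        # int initial value -> default integer AccumulatorParam

    def tally(x):
        global counter                 # workers can only add to the accumulator
        counter += 1

    sc.parallelize(range(100)).foreach(tally)
    print(counter.value)               # merged total, readable only on the driver
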
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r--    python/pyspark/context.py    38
1 file changed, 38 insertions(+), 0 deletions(-)
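
For value types without a built-in default, the accumulator() docstring in the diff below describes a two-method AccumulatorParam contract (zero and addInPlace). A sketch of a custom vector accumulator might look like this; the class name and element-wise handling are illustrative, not taken from the patch:

    class VectorAccumulatorParam(object):
        def zero(self, value):
            # A "zero value" with the same dimensions as the initial value.
            return [0.0] * len(value)

        def addInPlace(self, val1, val2):
            # Add element-wise; updating val1 in place and returning it is allowed.
            for i in range(len(val1)):
                val1[i] += val2[i]
            return val1

    vec = sc.accumulator([0.0, 0.0, 0.0], VectorAccumulatorParam())
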
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index e486f206b0..1e2f845f9c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -2,6 +2,8 @@ import os
 import atexit
 from tempfile import NamedTemporaryFile
 
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
@@ -22,6 +24,7 @@ class SparkContext(object):
     _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
     _takePartition = jvm.PythonRDD.takePartition
+    _next_accum_id = 0
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -52,6 +55,14 @@
         self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)
 
+        # Create a single Accumulator in Java that we'll send all our updates through;
+        # they will be passed back to us through a TCP server
+        self._accumulatorServer = accumulators._start_update_server()
+        (host, port) = self._accumulatorServer.server_address
+        self._javaAccumulator = self._jsc.accumulator(
+                self.jvm.java.util.ArrayList(),
+                self.jvm.PythonAccumulatorParam(host, port))
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
@@ -74,6 +85,8 @@
     def __del__(self):
         if self._jsc:
             self._jsc.stop()
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
 
     def stop(self):
         """
@@ -129,6 +142,31 @@
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
+    def accumulator(self, value, accum_param=None):
+        """
+        Create an C{Accumulator} with the given initial value, using a given
+        AccumulatorParam helper object to define how to add values of the data
+        type if provided. Default AccumulatorParams are used for integers and
+        floating-point numbers if you do not provide one. For other types, the
+        AccumulatorParam must implement two methods:
+        - C{zero(value)}: provide a "zero value" for the type, compatible in
+          dimensions with the provided C{value} (e.g., a zero vector).
+        - C{addInPlace(val1, val2)}: add two values of the accumulator's data
+          type, returning a new value; for efficiency, can also update C{val1}
+          in place and return it.
+        """
+        if accum_param == None:
+            if isinstance(value, int):
+                accum_param = accumulators.INT_ACCUMULATOR_PARAM
+            elif isinstance(value, float):
+                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
+            elif isinstance(value, complex):
+                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+            else:
+                raise Exception("No default accumulator param for type %s" % type(value))
+        SparkContext._next_accum_id += 1
+        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
     def addFile(self, path):
         """
         Add a file to be downloaded into the working directory of this Spark