Diffstat (limited to 'python/pyspark/context.py')
 python/pyspark/context.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index e486f206b0..dcbed37270 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -2,6 +2,8 @@ import os
import atexit
from tempfile import NamedTemporaryFile
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
@@ -22,6 +24,7 @@ class SparkContext(object):
_readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
_writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
_takePartition = jvm.PythonRDD.takePartition
+ _next_accum_id = 0
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
@@ -52,6 +55,14 @@ class SparkContext(object):
self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array)
+ # Create a single Accumulator in Java that we'll send all our updates through;
+ # they will be passed back to us through a TCP server
+ self._accumulatorServer = accumulators._start_update_server()
+ (host, port) = self._accumulatorServer.server_address
+ self._javaAccumulator = self._jsc.accumulator(
+ self.jvm.java.util.ArrayList(),
+ self.jvm.PythonAccumulatorParam(host, port))
+
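To make the update path above concrete, here is a minimal sketch of what a TCP update server such as accumulators._start_update_server() could look like. It is illustrative only, not the actual pyspark.accumulators implementation; it assumes updates arrive as a single pickled list of (accumulator_id, update) pairs and that accumulators register themselves in a module-level dict.

    # Minimal sketch (Python 2 era, matching this codebase); the names
    # _accumulatorRegistry and _UpdateHandler are assumptions, not the real internals.
    import cPickle
    import threading
    import SocketServer

    _accumulatorRegistry = {}  # assumed: accumulator id -> Accumulator object

    class _UpdateHandler(SocketServer.StreamRequestHandler):
        def handle(self):
            # Assumed wire format: one pickled list of (id, update) pairs
            # sent by the JVM-side PythonAccumulatorParam.
            for (aid, update) in cPickle.load(self.rfile):
                # += is assumed to delegate to the AccumulatorParam's addInPlace
                _accumulatorRegistry[aid] += update

    def _start_update_server():
        # Bind to an ephemeral port; the resulting (host, port) pair is handed to the JVM.
        server = SocketServer.TCPServer(("localhost", 0), _UpdateHandler)
        thread = threading.Thread(target=server.serve_forever)
        thread.daemon = True  # do not block interpreter shutdown
        thread.start()
        return server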
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
@@ -74,6 +85,8 @@ class SparkContext(object):
def __del__(self):
if self._jsc:
self._jsc.stop()
+ if self._accumulatorServer:
+ self._accumulatorServer.shutdown()
def stop(self):
"""
@@ -110,6 +123,10 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
+ def _checkpointFile(self, name):
+ jrdd = self._jsc.checkpointFile(name)
+ return RDD(jrdd, self)
+
def union(self, rdds):
"""
Build the union of a list of RDDs.
@@ -129,6 +146,31 @@ class SparkContext(object):
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
+ def accumulator(self, value, accum_param=None):
+ """
+ Create an C{Accumulator} with the given initial value, using a given
+ C{AccumulatorParam} helper object to define how values of its data
+ type are added together. Default AccumulatorParams are provided for
+ integer, floating-point, and complex values if you do not supply one.
+ For other types, the C{AccumulatorParam} must implement two methods:
+ - C{zero(value)}: provide a "zero value" for the type, compatible in
+ dimensions with the provided C{value} (e.g., a zero vector).
+ - C{addInPlace(val1, val2)}: add two values of the accumulator's data
+ type, returning a new value; for efficiency, can also update C{val1}
+ in place and return it.
+ """
+ if accum_param is None:
+ if isinstance(value, int):
+ accum_param = accumulators.INT_ACCUMULATOR_PARAM
+ elif isinstance(value, float):
+ accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
+ elif isinstance(value, complex):
+ accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+ else:
+ raise Exception("No default accumulator param for type %s" % type(value))
+ SparkContext._next_accum_id += 1
+ return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
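The zero/addInPlace contract documented above can be illustrated with a small custom accumulator for fixed-length vectors. This is a usage sketch; VectorAccumulatorParam is an illustrative name, not something shipped with PySpark, and the commented usage lines assume an existing SparkContext sc and an RDD rdd of numbers.

    # Custom AccumulatorParam for element-wise summing of fixed-length vectors.
    class VectorAccumulatorParam(object):
        def zero(self, value):
            # A "zero value" with the same dimensions as the initial value.
            return [0.0] * len(value)

        def addInPlace(self, val1, val2):
            # Add val2 into val1 element-wise and return the updated val1.
            for i in range(len(val1)):
                val1[i] += val2[i]
            return val1

    # Hypothetical usage:
    # vec_acc = sc.accumulator([0.0, 0.0, 0.0], VectorAccumulatorParam())
    # rdd.foreach(lambda x: vec_acc.add([x, x * x, 1.0]))
    # vec_acc.value  # -> [sum, sum of squares, count]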
def addFile(self, path):
"""
Add a file to be downloaded into the working directory of this Spark
@@ -157,3 +199,15 @@ class SparkContext(object):
filename = path.split("/")[-1]
os.environ["PYTHONPATH"] = \
"%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+ def setCheckpointDir(self, dirName, useExisting=False):
+ """
+ Set the directory under which RDDs are going to be checkpointed. The
+ directory must be an HDFS path if running on a cluster.
+
+ If the directory does not exist, it will be created. If the directory
+ exists and C{useExisting} is set to true, then the existing directory
+ will be used. Otherwise an exception will be thrown to prevent
+ accidental overwriting of checkpoint files in the existing directory.
+ """
+ self._jsc.sc().setCheckpointDir(dirName, useExisting)
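A short end-to-end usage sketch of the checkpointing hooks added here. It assumes a corresponding RDD.checkpoint() method on the Python side, which is not part of this diff.

    from pyspark import SparkContext

    sc = SparkContext("local", "checkpoint-demo")
    # Create the checkpoint directory if needed; reuse it if it already exists.
    sc.setCheckpointDir("/tmp/spark-checkpoints", useExisting=True)

    rdd = sc.parallelize(range(100)).map(lambda x: x * 2)
    rdd.checkpoint()  # mark the RDD for checkpointing (assumed API)
    rdd.count()       # an action forces computation and writes the checkpoint
    sc.stop()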