Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r-- | python/pyspark/context.py | 147 |
1 files changed, 126 insertions, 21 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index e486f206b0..6831f9b7f8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,8 +1,13 @@
 import os
-import atexit
+import shutil
+import sys
+from threading import Lock
 from tempfile import NamedTemporaryFile
 
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
 from pyspark.rdd import RDD
@@ -17,11 +22,13 @@ class SparkContext(object):
     broadcast variables on that cluster.
     """
 
-    gateway = launch_gateway()
-    jvm = gateway.jvm
-    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
-    _takePartition = jvm.PythonRDD.takePartition
+    _gateway = None
+    _jvm = None
+    _writeIteratorToPickleFile = None
+    _takePartition = None
+    _next_accum_id = 0
+    _active_spark_context = None
+    _lock = Lock()
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -41,6 +48,18 @@ class SparkContext(object):
               Java object. Set 1 to disable batching or -1 to use an unlimited
               batch size.
        """
+        with SparkContext._lock:
+            if SparkContext._active_spark_context:
+                raise ValueError("Cannot run multiple SparkContexts at once")
+            else:
+                SparkContext._active_spark_context = self
+                if not SparkContext._gateway:
+                    SparkContext._gateway = launch_gateway()
+                    SparkContext._jvm = SparkContext._gateway.jvm
+                    SparkContext._writeIteratorToPickleFile = \
+                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                    SparkContext._takePartition = \
+                        SparkContext._jvm.PythonRDD.takePartition
         self.master = master
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -48,10 +67,18 @@ class SparkContext(object):
         self.batchSize = batchSize # -1 represents a unlimited batch size
 
         # Create the Java SparkContext through Py4J
-        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
-        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
+        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
             empty_string_array)
 
+        # Create a single Accumulator in Java that we'll send all our updates through;
+        # they will be passed back to us through a TCP server
+        self._accumulatorServer = accumulators._start_update_server()
+        (host, port) = self._accumulatorServer.server_address
+        self._javaAccumulator = self._jsc.accumulator(
+            self._jvm.java.util.ArrayList(),
+            self._jvm.PythonAccumulatorParam(host, port))
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
@@ -62,6 +89,13 @@ class SparkContext(object):
         # Deploy any code dependencies specified in the constructor
         for path in (pyFiles or []):
             self.addPyFile(path)
+        SparkFiles._sc = self
+        sys.path.append(SparkFiles.getRootDirectory())
+
+        # Create a temporary directory inside spark.local.dir:
+        local_dir = self._jvm.spark.Utils.getLocalDir()
+        self._temp_dir = \
+            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
 
     @property
     def defaultParallelism(self):
@@ -72,15 +106,20 @@
         return self._jsc.sc().defaultParallelism()
 
     def __del__(self):
-        if self._jsc:
-            self._jsc.stop()
+        self.stop()
 
     def stop(self):
         """
         Shut down the SparkContext.
         """
-        self._jsc.stop()
-        self._jsc = None
+        if self._jsc:
+            self._jsc.stop()
+            self._jsc = None
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
+            self._accumulatorServer = None
+        with SparkContext._lock:
+            SparkContext._active_spark_context = None
 
     def parallelize(self, c, numSlices=None):
         """
@@ -90,14 +129,14 @@
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False)
-        atexit.register(lambda: os.unlink(tempFile.name))
+        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         if self.batchSize != 1:
             c = batched(c, self.batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
+        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
@@ -110,6 +149,10 @@
         jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
 
+    def _checkpointFile(self, name):
+        jrdd = self._jsc.checkpointFile(name)
+        return RDD(jrdd, self)
+
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
@@ -129,12 +172,48 @@
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
+    def accumulator(self, value, accum_param=None):
+        """
+        Create an L{Accumulator} with the given initial value, using a given
+        L{AccumulatorParam} helper object to define how to add values of the
+        data type if provided. Default AccumulatorParams are used for integers
+        and floating-point numbers if you do not provide one. For other types,
+        a custom AccumulatorParam can be used.
+        """
+        if accum_param == None:
+            if isinstance(value, int):
+                accum_param = accumulators.INT_ACCUMULATOR_PARAM
+            elif isinstance(value, float):
+                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
+            elif isinstance(value, complex):
+                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+            else:
+                raise Exception("No default accumulator param for type %s" % type(value))
+        SparkContext._next_accum_id += 1
+        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
     def addFile(self, path):
         """
-        Add a file to be downloaded into the working directory of this Spark
-        job on every node. The C{path} passed can be either a local file,
-        a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
-        HTTPS or FTP URI.
+        Add a file to be downloaded with this Spark job on every node.
+        The C{path} passed can be either a local file, a file in HDFS
+        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+        FTP URI.
+
+        To access the file in Spark jobs, use
+        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
+        download location.
+
+        >>> from pyspark import SparkFiles
+        >>> path = os.path.join(tempdir, "test.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("100")
+        >>> sc.addFile(path)
+        >>> def func(iterator):
+        ...    with open(SparkFiles.get("test.txt")) as testFile:
+        ...        fileVal = int(testFile.readline())
+        ...    return [x * 100 for x in iterator]
+        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+        [100, 200, 300, 400]
         """
         self._jsc.sc().addFile(path)
 
@@ -155,5 +234,31 @@
         """
         self.addFile(path)
         filename = path.split("/")[-1]
-        os.environ["PYTHONPATH"] = \
-            "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+    def setCheckpointDir(self, dirName, useExisting=False):
+        """
+        Set the directory under which RDDs are going to be checkpointed. The
+        directory must be an HDFS path if running on a cluster.
+
+        If the directory does not exist, it will be created. If the directory
+        exists and C{useExisting} is set to true, then the existing directory
+        will be used. Otherwise an exception will be thrown to prevent
+        accidental overriding of checkpoint files in the existing directory.
+        """
+        self._jsc.sc().setCheckpointDir(dirName, useExisting)
+
+
+def _test():
+    import atexit
+    import doctest
+    import tempfile
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    globs['tempdir'] = tempfile.mkdtemp()
+    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+    doctest.testmod(globs=globs)
+    globs['sc'].stop()
+
+
+if __name__ == "__main__":
+    _test()
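The patch makes SparkContext a per-process singleton: __init__ takes SparkContext._lock and rejects a second live context, while stop() now clears the active-context slot in addition to shutting down the JavaSparkContext and the accumulator server. A minimal sketch of that behaviour, assuming local mode; the job names are illustrative only.

# Minimal sketch of the single-active-context guard added above; assumes local
# mode and that pyspark is importable. Job names are illustrative only.
from pyspark.context import SparkContext

sc = SparkContext("local", "First")
try:
    SparkContext("local", "Second")       # rejected while `sc` is still active
except ValueError as e:
    print(e)                              # Cannot run multiple SparkContexts at once

sc.stop()                                 # releases the active-context slot...
sc2 = SparkContext("local", "Second")     # ...so a new context can be created
sc2.stop()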
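The new accumulator() method builds on the Python-side Accumulator and the PythonAccumulatorParam/TCP update server wired up in __init__, with default AccumulatorParams for int, float, and complex values. A rough usage sketch, assuming RDD.foreach() is available in this build and using the driver-global pattern for updates from workers; the names are illustrative, not part of the patch.

# Rough sketch of the accumulator API added above. Assumes RDD.foreach() exists
# in this build; `sc` and `counter` are example names, not part of the patch.
from pyspark.context import SparkContext

sc = SparkContext("local", "AccumulatorSketch", batchSize=2)
counter = sc.accumulator(0)               # default INT_ACCUMULATOR_PARAM

def add_to_counter(x):
    global counter                        # workers may only add; reads happen on the driver
    counter += x

sc.parallelize([1, 2, 3, 4]).foreach(add_to_counter)
print(counter.value)                      # expected: 10 once worker updates arrive
sc.stop()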
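setCheckpointDir() and the private _checkpointFile() helper expose the JVM's checkpointing support; the checkpoint itself is requested per RDD. A rough sketch under the assumption that RDD.checkpoint() is provided by the matching rdd.py changes, which are not shown in this diff; the temporary directory is only an example stand-in for the HDFS path a real cluster would need.

# Rough sketch of the checkpointing hooks added above. Assumes RDD.checkpoint()
# comes from the accompanying rdd.py changes (outside this diff); the temp
# directory stands in for an HDFS path, which a real cluster would require.
import tempfile

from pyspark.context import SparkContext

sc = SparkContext("local", "CheckpointSketch")
sc.setCheckpointDir(tempfile.mkdtemp(), useExisting=True)

rdd = sc.parallelize([1, 2, 3, 4]).map(lambda x: x * x)
rdd.checkpoint()                          # materialized on the next action
print(rdd.collect())                      # [1, 4, 9, 16]
sc.stop()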