Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r--  python/pyspark/context.py | 139
1 file changed, 104 insertions(+), 35 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9c..657fe6f989 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,10 +1,13 @@
 import os
-import atexit
+import shutil
+import sys
+from threading import Lock
 from tempfile import NamedTemporaryFile
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
 from pyspark.rdd import RDD
@@ -19,12 +22,13 @@ class SparkContext(object):
     broadcast variables on that cluster.
     """
 
-    gateway = launch_gateway()
-    jvm = gateway.jvm
-    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
-    _takePartition = jvm.PythonRDD.takePartition
+    _gateway = None
+    _jvm = None
+    _writeIteratorToPickleFile = None
+    _takePartition = None
     _next_accum_id = 0
+    _active_spark_context = None
+    _lock = Lock()
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -44,6 +48,18 @@ class SparkContext(object):
               Java object. Set 1 to disable batching or -1 to use an
               unlimited batch size.
         """
+        with SparkContext._lock:
+            if SparkContext._active_spark_context:
+                raise ValueError("Cannot run multiple SparkContexts at once")
+            else:
+                SparkContext._active_spark_context = self
+                if not SparkContext._gateway:
+                    SparkContext._gateway = launch_gateway()
+                    SparkContext._jvm = SparkContext._gateway.jvm
+                    SparkContext._writeIteratorToPickleFile = \
+                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                    SparkContext._takePartition = \
+                        SparkContext._jvm.PythonRDD.takePartition
         self.master = master
         self.jobName = jobName
         self.sparkHome = sparkHome or None  # None becomes null in Py4J
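
The block added above makes the Py4J gateway lazily initialized and enforces a single active SparkContext per Python process. A minimal sketch of how this behaves from user code (the master string and app names are illustrative):

    from pyspark import SparkContext

    sc = SparkContext("local", "First App")
    try:
        SparkContext("local", "Second App")   # second context in the same process
    except ValueError as e:
        print(e)                              # Cannot run multiple SparkContexts at once

    sc.stop()                                 # clears _active_spark_context (see the stop() hunk below)
    sc = SparkContext("local", "Third App")   # allowed again after stop()
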
@@ -51,8 +67,8 @@ class SparkContext(object):
         self.batchSize = batchSize  # -1 represents an unlimited batch size
 
         # Create the Java SparkContext through Py4J
-        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
-        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
+        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                                empty_string_array)
 
         # Create a single Accumulator in Java that we'll send all our updates through;
@@ -60,8 +76,8 @@ class SparkContext(object):
         self._accumulatorServer = accumulators._start_update_server()
         (host, port) = self._accumulatorServer.server_address
         self._javaAccumulator = self._jsc.accumulator(
-            self.jvm.java.util.ArrayList(),
-            self.jvm.PythonAccumulatorParam(host, port))
+            self._jvm.java.util.ArrayList(),
+            self._jvm.PythonAccumulatorParam(host, port))
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
@@ -73,6 +89,13 @@ class SparkContext(object):
         # Deploy any code dependencies specified in the constructor
         for path in (pyFiles or []):
             self.addPyFile(path)
+        SparkFiles._sc = self
+        sys.path.append(SparkFiles.getRootDirectory())
+
+        # Create a temporary directory inside spark.local.dir:
+        local_dir = self._jvm.spark.Utils.getLocalDir()
+        self._temp_dir = \
+            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
 
     @property
     def defaultParallelism(self):
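
Because the SparkFiles root directory is appended to sys.path here, dependencies shipped with addPyFile become importable by name. A hedged sketch (the path, module name my_module, and function transform are hypothetical):

    from pyspark import SparkContext

    sc = SparkContext("local", "Deps Example")
    sc.addPyFile("/path/to/my_module.py")     # hypothetical dependency file
    # The file is fetched into SparkFiles.getRootDirectory(), which __init__
    # has put on sys.path, so the module can be imported by name in driver
    # code and referenced inside tasks.
    import my_module                          # hypothetical module name
    print(sc.parallelize([1, 2, 3]).map(my_module.transform).collect())
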
@@ -83,17 +106,20 @@ class SparkContext(object):
         return self._jsc.sc().defaultParallelism()
 
     def __del__(self):
-        if self._jsc:
-            self._jsc.stop()
-        if self._accumulatorServer:
-            self._accumulatorServer.shutdown()
+        self.stop()
 
     def stop(self):
         """
         Shut down the SparkContext.
         """
-        self._jsc.stop()
-        self._jsc = None
+        if self._jsc:
+            self._jsc.stop()
+            self._jsc = None
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
+            self._accumulatorServer = None
+        with SparkContext._lock:
+            SparkContext._active_spark_context = None
 
     def parallelize(self, c, numSlices=None):
         """
@@ -103,14 +129,14 @@ class SparkContext(object):
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False)
-        atexit.register(lambda: os.unlink(tempFile.name))
+        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         if self.batchSize != 1:
             c = batched(c, self.batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
+        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
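
With this hunk, parallelize() writes the pickled elements to a file under the per-context temporary directory created in __init__ (inside spark.local.dir) instead of relying on an atexit hook to unlink it. Usage is unchanged; a small sketch, assuming an existing context sc:

    rdd = sc.parallelize(range(10), numSlices=4)
    print(rdd.count())     # 10
    print(rdd.collect())   # [0, 1, 2, ..., 9]
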
@@ -123,6 +149,10 @@ class SparkContext(object):
         jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
 
+    def _checkpointFile(self, name):
+        jrdd = self._jsc.checkpointFile(name)
+        return RDD(jrdd, self)
+
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
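
For reference, union() combines several RDDs from the same context, while the private _checkpointFile() added above reloads a previously checkpointed RDD (see setCheckpointDir() later in this patch). A quick union sketch, assuming an existing context sc:

    rdd1 = sc.parallelize([1, 2, 3])
    rdd2 = sc.parallelize([4, 5])
    print(sc.union([rdd1, rdd2]).collect())   # [1, 2, 3, 4, 5]
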
@@ -144,16 +174,11 @@ class SparkContext(object):
 
     def accumulator(self, value, accum_param=None):
         """
-        Create an C{Accumulator} with the given initial value, using a given
-        AccumulatorParam helper object to define how to add values of the data
-        type if provided. Default AccumulatorParams are used for integers and
-        floating-point numbers if you do not provide one. For other types, the
-        AccumulatorParam must implement two methods:
-        - C{zero(value)}: provide a "zero value" for the type, compatible in
-          dimensions with the provided C{value} (e.g., a zero vector).
-        - C{addInPlace(val1, val2)}: add two values of the accumulator's data
-          type, returning a new value; for efficiency, can also update C{val1}
-          in place and return it.
+        Create an L{Accumulator} with the given initial value, using a given
+        L{AccumulatorParam} helper object to define how to add values of the
+        data type if provided. Default AccumulatorParams are used for integers
+        and floating-point numbers if you do not provide one. For other types,
+        a custom AccumulatorParam can be used.
         """
         if accum_param == None:
             if isinstance(value, int):
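
Since the docstring now defers the details to L{AccumulatorParam}, here is a hedged sketch of the custom-type case, assuming an existing context sc; the vector-adding param is illustrative and implements the zero() and addInPlace() methods described in the old docstring text:

    from pyspark.accumulators import AccumulatorParam

    class VectorAccumulatorParam(AccumulatorParam):
        # "zero value" compatible in dimensions with the initial value
        def zero(self, value):
            return [0.0] * len(value)
        # add two values, updating the first in place for efficiency
        def addInPlace(self, val1, val2):
            for i in range(len(val1)):
                val1[i] += val2[i]
            return val1

    va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())

    def add_to_vector(x):
        global va
        va += x

    sc.parallelize([[1.0, 1.0, 1.0]] * 3).foreach(add_to_vector)
    print(va.value)   # [4.0, 5.0, 6.0]
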
@@ -169,10 +194,26 @@ class SparkContext(object):
 
     def addFile(self, path):
         """
-        Add a file to be downloaded into the working directory of this Spark
-        job on every node. The C{path} passed can be either a local file,
-        a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
-        HTTPS or FTP URI.
+        Add a file to be downloaded with this Spark job on every node.
+        The C{path} passed can be either a local file, a file in HDFS
+        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+        FTP URI.
+
+        To access the file in Spark jobs, use
+        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
+        download location.
+
+        >>> from pyspark import SparkFiles
+        >>> path = os.path.join(tempdir, "test.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("100")
+        >>> sc.addFile(path)
+        >>> def func(iterator):
+        ...    with open(SparkFiles.get("test.txt")) as testFile:
+        ...        fileVal = int(testFile.readline())
+        ...        return [x * 100 for x in iterator]
+        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+        [100, 200, 300, 400]
         """
         self._jsc.sc().addFile(path)
 
@@ -193,5 +234,33 @@ class SparkContext(object):
"""
self.addFile(path)
filename = path.split("/")[-1]
- os.environ["PYTHONPATH"] = \
- "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+ def setCheckpointDir(self, dirName, useExisting=False):
+ """
+ Set the directory under which RDDs are going to be checkpointed. The
+ directory must be a HDFS path if running on a cluster.
+
+ If the directory does not exist, it will be created. If the directory
+ exists and C{useExisting} is set to true, then the exisiting directory
+ will be used. Otherwise an exception will be thrown to prevent
+ accidental overriding of checkpoint files in the existing directory.
+ """
+ self._jsc.sc().setCheckpointDir(dirName, useExisting)
+
+
+def _test():
+ import atexit
+ import doctest
+ import tempfile
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['tempdir'] = tempfile.mkdtemp()
+ atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+ (failure_count, test_count) = doctest.testmod(globs=globs)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
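
A hedged end-to-end sketch of the new checkpointing hooks. It assumes the companion RDD.checkpoint() method added to pyspark.rdd in the same change, and an illustrative local directory; on a real cluster the checkpoint directory must be an HDFS path:

    from pyspark import SparkContext

    sc = SparkContext("local", "Checkpoint Example")
    sc.setCheckpointDir("/tmp/pyspark-checkpoints", useExisting=True)
    rdd = sc.parallelize(range(100)).map(lambda x: x * x)
    rdd.checkpoint()   # mark for checkpointing (assumes RDD.checkpoint())
    rdd.count()        # forces evaluation, materializing the checkpoint
    sc.stop()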