aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r--python/pyspark/context.py174
1 files changed, 124 insertions, 50 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 597110321a..f955aad7a4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -24,9 +24,10 @@ from tempfile import NamedTemporaryFile
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@@ -42,21 +43,22 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeIteratorToPickleFile = None
- _takePartition = None
+ _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
- def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024):
+
+ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
+ environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
"""
- Create a new SparkContext.
+ Create a new SparkContext. At least the master and app name should be set,
+ either through the named parameters here or through C{conf}.
@param master: Cluster URL to connect to
(e.g. mesos://host:port, spark://host:port, local[4]).
- @param jobName: A name for your job, to display on the cluster web UI
+ @param appName: A name for your job, to display on the cluster web UI.
@param sparkHome: Location where Spark is installed on cluster nodes.
@param pyFiles: Collection of .zip or .py files to send to the cluster
and add to PYTHONPATH. These can be paths on the local file
@@ -66,29 +68,59 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
+ @param serializer: The serializer for RDDs.
+ @param conf: A L{SparkConf} object setting Spark properties.
+
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+
+ >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
- with SparkContext._lock:
- if SparkContext._active_spark_context:
- raise ValueError("Cannot run multiple SparkContexts at once")
- else:
- SparkContext._active_spark_context = self
- if not SparkContext._gateway:
- SparkContext._gateway = launch_gateway()
- SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeIteratorToPickleFile = \
- SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
- SparkContext._takePartition = \
- SparkContext._jvm.PythonRDD.takePartition
- self.master = master
- self.jobName = jobName
- self.sparkHome = sparkHome or None # None becomes null in Py4J
+ SparkContext._ensure_initialized(self)
+
self.environment = environment or {}
- self.batchSize = batchSize # -1 represents a unlimited batch size
+ self._conf = conf or SparkConf(_jvm=self._jvm)
+ self._batchSize = batchSize # -1 represents an unlimited batch size
+ self._unbatched_serializer = serializer
+ if batchSize == 1:
+ self.serializer = self._unbatched_serializer
+ else:
+ self.serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
+
+ # Set any parameters passed directly to us on the conf
+ if master:
+ self._conf.setMaster(master)
+ if appName:
+ self._conf.setAppName(appName)
+ if sparkHome:
+ self._conf.setSparkHome(sparkHome)
+ if environment:
+ for key, value in environment.iteritems():
+ self._conf.setExecutorEnv(key, value)
+
+ # Check that we have at least the required parameters
+ if not self._conf.contains("spark.master"):
+ raise Exception("A master URL must be set in your configuration")
+ if not self._conf.contains("spark.app.name"):
+ raise Exception("An application name must be set in your configuration")
+
+ # Read back our properties from the conf in case we loaded some of them from
+ # the classpath or an external config file
+ self.master = self._conf.get("spark.master")
+ self.appName = self._conf.get("spark.app.name")
+ self.sparkHome = self._conf.get("spark.home", None)
+ for (k, v) in self._conf.getAll():
+ if k.startswith("spark.executorEnv."):
+ varName = k[len("spark.executorEnv."):]
+ self.environment[varName] = v
# Create the Java SparkContext through Py4J
- empty_string_array = self._gateway.new_array(self._jvm.String, 0)
- self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
- empty_string_array)
+ self._jsc = self._jvm.JavaSparkContext(self._conf._jconf)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
@@ -99,6 +131,7 @@ class SparkContext(object):
self._jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
# been pickled, so it can determine which Java broadcast objects to
@@ -115,10 +148,33 @@ class SparkContext(object):
self.addPyFile(path)
# Create a temporary directory inside spark.local.dir:
- local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir()
+ local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ @classmethod
+ def _ensure_initialized(cls, instance=None):
+ with SparkContext._lock:
+ if not SparkContext._gateway:
+ SparkContext._gateway = launch_gateway()
+ SparkContext._jvm = SparkContext._gateway.jvm
+ SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
+
+ if instance:
+ if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = instance
+
+ @classmethod
+ def setSystemProperty(cls, key, value):
+ """
+ Set a Java system property, such as spark.executor.memory. This must
+ must be invoked before instantiating SparkContext.
+ """
+ SparkContext._ensure_initialized()
+ SparkContext._jvm.java.lang.System.setProperty(key, value)
+
@property
def defaultParallelism(self):
"""
@@ -158,15 +214,17 @@ class SparkContext(object):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = min(len(c) // numSlices, self.batchSize)
+ batchSize = min(len(c) // numSlices, self._batchSize)
if batchSize > 1:
- c = batched(c, batchSize)
- for x in c:
- write_with_length(dump_pickle(x), tempFile)
+ serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
+ else:
+ serializer = self._unbatched_serializer
+ serializer.dump_stream(c, tempFile)
tempFile.close()
- readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
- jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
- return RDD(jrdd, self)
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self, serializer)
def textFile(self, name, minSplits=None):
"""
@@ -175,29 +233,50 @@ class SparkContext(object):
RDD of Strings.
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- jrdd = self._jsc.textFile(name, minSplits)
- return RDD(jrdd, self)
+ return RDD(self._jsc.textFile(name, minSplits), self,
+ MUTF8Deserializer())
- def _checkpointFile(self, name):
+ def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
- return RDD(jrdd, self)
+ return RDD(jrdd, self, input_deserializer)
def union(self, rdds):
"""
Build the union of a list of RDDs.
+
+ This supports unions() of RDDs with different serialized formats,
+ although this forces them to be reserialized using the default
+ serializer:
+
+ >>> path = os.path.join(tempdir, "union-text.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("Hello")
+ >>> textFile = sc.textFile(path)
+ >>> textFile.collect()
+ [u'Hello']
+ >>> parallelized = sc.parallelize(["World!"])
+ >>> sorted(sc.union([textFile, parallelized]).collect())
+ [u'Hello', 'World!']
"""
+ first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+ if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+ rdds = [x._reserialize() for x in rdds]
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
- rest = ListConverter().convert(rest, self.gateway._gateway_client)
- return RDD(self._jsc.union(first, rest), self)
+ rest = ListConverter().convert(rest, self._gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self,
+ rdds[0]._jrdd_deserializer)
def broadcast(self, value):
"""
- Broadcast a read-only variable to the cluster, returning a C{Broadcast}
+ Broadcast a read-only variable to the cluster, returning a
+ L{Broadcast<pyspark.broadcast.Broadcast>}
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
"""
- jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ pickleSer = PickleSerializer()
+ pickled = pickleSer.dumps(value)
+ jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
@@ -209,7 +288,7 @@ class SparkContext(object):
and floating-point numbers if you do not provide one. For other types,
a custom AccumulatorParam can be used.
"""
- if accum_param == None:
+ if accum_param is None:
if isinstance(value, int):
accum_param = accumulators.INT_ACCUMULATOR_PARAM
elif isinstance(value, float):
@@ -268,17 +347,12 @@ class SparkContext(object):
self._python_includes.append(filename)
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
- def setCheckpointDir(self, dirName, useExisting=False):
+ def setCheckpointDir(self, dirName):
"""
Set the directory under which RDDs are going to be checkpointed. The
directory must be a HDFS path if running on a cluster.
-
- If the directory does not exist, it will be created. If the directory
- exists and C{useExisting} is set to true, then the exisiting directory
- will be used. Otherwise an exception will be thrown to prevent
- accidental overriding of checkpoint files in the existing directory.
"""
- self._jsc.sc().setCheckpointDir(dirName, useExisting)
+ self._jsc.sc().setCheckpointDir(dirName)
def _getJavaStorageLevel(self, storageLevel):
"""