-rw-r--r--  python/pyspark/conf.py     15
-rw-r--r--  python/pyspark/context.py  17
2 files changed, 22 insertions, 10 deletions
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index 3870cd8f2b..49b68d57ab 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -75,7 +75,7 @@ class SparkConf(object):
and can no longer be modified by the user.
"""
- def __init__(self, loadDefaults=True, _jvm=None):
+ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
"""
Create a new Spark configuration.
@@ -83,11 +83,16 @@ class SparkConf(object):
properties (True by default)
@param _jvm: internal parameter used to pass a handle to the
Java VM; does not need to be set by users
+ @param _jconf: Optionally pass in an existing SparkConf handle
+ to use its parameters
"""
- from pyspark.context import SparkContext
- SparkContext._ensure_initialized()
- _jvm = _jvm or SparkContext._jvm
- self._jconf = _jvm.SparkConf(loadDefaults)
+ if _jconf:
+ self._jconf = _jconf
+ else:
+ from pyspark.context import SparkContext
+ SparkContext._ensure_initialized()
+ _jvm = _jvm or SparkContext._jvm
+ self._jconf = _jvm.SparkConf(loadDefaults)
def set(self, key, value):
"""Set a configuration property."""
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index f318b5d9a7..93faa2e385 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -51,7 +51,8 @@ class SparkContext(object):
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
+ environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
+ gateway=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -70,6 +71,8 @@ class SparkContext(object):
unlimited batch size.
@param serializer: The serializer for RDDs.
@param conf: A L{SparkConf} object setting Spark properties.
+ @param gateway: Use an existing gateway and JVM, otherwise a new JVM
will be instantiated.
>>> from pyspark.context import SparkContext
@@ -80,7 +83,7 @@ class SparkContext(object):
...
ValueError:...
"""
- SparkContext._ensure_initialized(self)
+ SparkContext._ensure_initialized(self, gateway=gateway)
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
@@ -120,7 +123,7 @@ class SparkContext(object):
self.environment[varName] = v
# Create the Java SparkContext through Py4J
- self._jsc = self._jvm.JavaSparkContext(self._conf._jconf)
+ self._jsc = self._initialize_context(self._conf._jconf)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
@@ -152,11 +155,15 @@ class SparkContext(object):
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ # Initialize SparkContext in a function to allow subclass-specific initialization
+ def _initialize_context(self, jconf):
+ return self._jvm.JavaSparkContext(jconf)
+
@classmethod
- def _ensure_initialized(cls, instance=None):
+ def _ensure_initialized(cls, instance=None, gateway=None):
with SparkContext._lock:
if not SparkContext._gateway:
- SparkContext._gateway = launch_gateway()
+ SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
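
Taken together, a hedged sketch of what these hooks allow: a caller can hand an already-launched gateway to SparkContext so no second JVM is spawned, and a subclass can customize JavaSparkContext creation by overriding _initialize_context alone. The class name and app name below are illustrative, not part of this patch.

    # Sketch only: LoggingSparkContext and the app name are illustrative;
    # the master/appName/gateway keyword arguments are the ones shown above.
    from pyspark.context import SparkContext
    from pyspark.java_gateway import launch_gateway

    class LoggingSparkContext(SparkContext):
        # Subclass-specific setup now only needs to override this hook.
        def _initialize_context(self, jconf):
            print("creating JavaSparkContext with: %s" % jconf.toDebugString())
            return self._jvm.JavaSparkContext(jconf)

    # Reuse a gateway launched elsewhere instead of starting a second JVM.
    gateway = launch_gateway()
    sc = LoggingSparkContext(master="local", appName="gateway-example",
                             gateway=gateway)
    sc.stop()
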