 python/pyspark/conf.py    | 15 ++++++++++-----
 python/pyspark/context.py | 17 ++++++++++++-----
 2 files changed, 22 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index 3870cd8f2b..49b68d57ab 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -75,7 +75,7 @@ class SparkConf(object):
     and can no longer be modified by the user.
     """
 
-    def __init__(self, loadDefaults=True, _jvm=None):
+    def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
         """
         Create a new Spark configuration.
 
@@ -83,11 +83,16 @@ class SparkConf(object):
                properties (True by default)
         @param _jvm: internal parameter used to pass a handle to the Java VM;
                does not need to be set by users
+        @param _jconf: Optionally pass in an existing SparkConf handle
+               to use its parameters
         """
-        from pyspark.context import SparkContext
-        SparkContext._ensure_initialized()
-        _jvm = _jvm or SparkContext._jvm
-        self._jconf = _jvm.SparkConf(loadDefaults)
+        if _jconf:
+            self._jconf = _jconf
+        else:
+            from pyspark.context import SparkContext
+            SparkContext._ensure_initialized()
+            _jvm = _jvm or SparkContext._jvm
+            self._jconf = _jvm.SparkConf(loadDefaults)
 
     def set(self, key, value):
         """Set a configuration property."""
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index f318b5d9a7..93faa2e385 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -51,7 +51,8 @@ class SparkContext(object):
 
 
     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-        environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
+        environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
+        gateway=None):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -70,6 +71,8 @@ class SparkContext(object):
                unlimited batch size.
         @param serializer: The serializer for RDDs.
         @param conf: A L{SparkConf} object setting Spark properties.
+        @param gateway: Use an existing gateway and JVM, otherwise a new JVM
+               will be instantiated.
 
 
         >>> from pyspark.context import SparkContext
@@ -80,7 +83,7 @@ class SparkContext(object):
         ...
             ValueError:...
         """
-        SparkContext._ensure_initialized(self)
+        SparkContext._ensure_initialized(self, gateway=gateway)
 
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
@@ -120,7 +123,7 @@ class SparkContext(object):
                 self.environment[varName] = v
 
         # Create the Java SparkContext through Py4J
-        self._jsc = self._jvm.JavaSparkContext(self._conf._jconf)
+        self._jsc = self._initialize_context(self._conf._jconf)
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
@@ -152,11 +155,15 @@ class SparkContext(object):
             self._temp_dir = \
                 self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
 
+    # Initialize SparkContext in function to allow subclass-specific initialization
+    def _initialize_context(self, jconf):
+        return self._jvm.JavaSparkContext(jconf)
+
     @classmethod
-    def _ensure_initialized(cls, instance=None):
+    def _ensure_initialized(cls, instance=None, gateway=None):
         with SparkContext._lock:
             if not SparkContext._gateway:
-                SparkContext._gateway = launch_gateway()
+                SparkContext._gateway = gateway or launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
                 SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
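
Taken together, the patch lets an embedding application hand PySpark an already-running Py4J gateway and an existing Java SparkConf instead of having a fresh JVM launched on its behalf. A minimal sketch of how the pieces compose, assuming a Py4J server is already listening; the port value and configuration values below are illustrative, not part of the patch, and older Py4J releases spell the connection as JavaGateway(GatewayClient(port=...)):

    from py4j.java_gateway import JavaGateway, GatewayParameters

    from pyspark.conf import SparkConf
    from pyspark.context import SparkContext

    # Attach to a JVM that the host application already started.
    gateway = JavaGateway(gateway_parameters=GatewayParameters(port=25333))

    # Wrap a Java-side SparkConf so its settings flow through unchanged.
    jconf = gateway.jvm.org.apache.spark.SparkConf(True)
    jconf.set("spark.master", "local[2]")
    jconf.set("spark.app.name", "embedded-app")
    conf = SparkConf(_jconf=jconf)

    # Passing the gateway keeps _ensure_initialized() from launching a new JVM.
    sc = SparkContext(conf=conf, gateway=gateway)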
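
The new _initialize_context hook, in turn, exists so a subclass can change how the Java-side context is constructed without duplicating the rest of __init__. A hypothetical override, with an invented configuration key purely for illustration:

    class TracingSparkContext(SparkContext):
        def _initialize_context(self, jconf):
            # Stamp the configuration before construction, then delegate to
            # the stock JavaSparkContext constructor as the base class does.
            jconf.set("spark.tracing.tag", "TracingSparkContext")  # illustrative key
            return self._jvm.JavaSparkContext(jconf)

Because SparkContext.__init__ now routes creation through this method, the subclass inherits the rest of the setup (accumulator server, temp directories, environment handling) unchanged.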