-rw-r--r--  python/pyspark/conf.py          71
-rw-r--r--  python/pyspark/context.py       16
-rw-r--r--  python/pyspark/java_gateway.py  13
3 files changed, 75 insertions, 25 deletions
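
Taken together, the three changes let a SparkConf be populated on the Python side before the JVM gateway exists: values are buffered in a plain dict, and SparkContext forwards them to spark-submit as "--conf key=value" arguments when it launches the gateway. A minimal usage sketch of the pattern this enables (the spark.driver.memory value is illustrative; such driver-side settings only take effect if spark-submit sees them before the JVM starts):

    from pyspark import SparkConf, SparkContext

    # SparkConf no longer forces the gateway to launch; the settings below are
    # buffered in a Python dict until a JVM-backed conf can be created.
    conf = (SparkConf()
            .setMaster("local[2]")
            .setAppName("conf-before-jvm")
            .set("spark.driver.memory", "2g"))

    print(conf.toDebugString())   # served from the dict; no JVM involved yet

    # SparkContext launches the gateway and passes each pair on the
    # spark-submit command line, so the driver-side setting is honored.
    sc = SparkContext(conf=conf)
    sc.stop()
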
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index 924da3eecf..64b6f238e9 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -52,6 +52,14 @@ spark.home=/path
>>> sorted(conf.getAll(), key=lambda p: p[0])
[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \
(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')]
+>>> conf._jconf.setExecutorEnv("VAR5", "value5")
+JavaObject id...
+>>> print(conf.toDebugString())
+spark.executorEnv.VAR1=value1
+spark.executorEnv.VAR3=value3
+spark.executorEnv.VAR4=value4
+spark.executorEnv.VAR5=value5
+spark.home=/path
"""
__all__ = ['SparkConf']
@@ -101,13 +109,24 @@ class SparkConf(object):
self._jconf = _jconf
else:
from pyspark.context import SparkContext
- SparkContext._ensure_initialized()
_jvm = _jvm or SparkContext._jvm
- self._jconf = _jvm.SparkConf(loadDefaults)
+
+ if _jvm is not None:
+ # The JVM is already up, so create self._jconf directly through it
+ self._jconf = _jvm.SparkConf(loadDefaults)
+ self._conf = None
+ else:
+ # The JVM has not been launched yet, so buffer the settings in self._conf for now
+ self._jconf = None
+ self._conf = {}
def set(self, key, value):
"""Set a configuration property."""
- self._jconf.set(key, unicode(value))
+ # If the JVM is already running, set the value on self._jconf; otherwise buffer it in self._conf until the JVM is launched.
+ if self._jconf is not None:
+ self._jconf.set(key, unicode(value))
+ else:
+ self._conf[key] = unicode(value)
return self
def setIfMissing(self, key, value):
@@ -118,17 +137,17 @@ class SparkConf(object):
def setMaster(self, value):
"""Set master URL to connect to."""
- self._jconf.setMaster(value)
+ self.set("spark.master", value)
return self
def setAppName(self, value):
"""Set application name."""
- self._jconf.setAppName(value)
+ self.set("spark.app.name", value)
return self
def setSparkHome(self, value):
"""Set path where Spark is installed on worker nodes."""
- self._jconf.setSparkHome(value)
+ self.set("spark.home", value)
return self
def setExecutorEnv(self, key=None, value=None, pairs=None):
@@ -136,10 +155,10 @@ class SparkConf(object):
if (key is not None and pairs is not None) or (key is None and pairs is None):
raise Exception("Either pass one key-value pair or a list of pairs")
elif key is not None:
- self._jconf.setExecutorEnv(key, value)
+ self.set("spark.executorEnv." + key, value)
elif pairs is not None:
for (k, v) in pairs:
- self._jconf.setExecutorEnv(k, v)
+ self.set("spark.executorEnv." + k, v)
return self
def setAll(self, pairs):
@@ -149,35 +168,49 @@ class SparkConf(object):
:param pairs: list of key-value pairs to set
"""
for (k, v) in pairs:
- self._jconf.set(k, v)
+ self.set(k, v)
return self
def get(self, key, defaultValue=None):
"""Get the configured value for some key, or return a default otherwise."""
if defaultValue is None: # Py4J doesn't call the right get() if we pass None
- if not self._jconf.contains(key):
- return None
- return self._jconf.get(key)
+ if self._jconf is not None:
+ if not self._jconf.contains(key):
+ return None
+ return self._jconf.get(key)
+ else:
+ if key not in self._conf:
+ return None
+ return self._conf[key]
else:
- return self._jconf.get(key, defaultValue)
+ if self._jconf is not None:
+ return self._jconf.get(key, defaultValue)
+ else:
+ return self._conf.get(key, defaultValue)
def getAll(self):
"""Get all values as a list of key-value pairs."""
- pairs = []
- for elem in self._jconf.getAll():
- pairs.append((elem._1(), elem._2()))
- return pairs
+ if self._jconf is not None:
+ return [(elem._1(), elem._2()) for elem in self._jconf.getAll()]
+ else:
+ return self._conf.items()
def contains(self, key):
"""Does this configuration contain a given key?"""
- return self._jconf.contains(key)
+ if self._jconf is not None:
+ return self._jconf.contains(key)
+ else:
+ return key in self._conf
def toDebugString(self):
"""
Returns a printable version of the configuration, as a list of
key=value pairs, one per line.
"""
- return self._jconf.toDebugString()
+ if self._jconf is not None:
+ return self._jconf.toDebugString()
+ else:
+ return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items())
def _test():
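
Before any gateway exists, every accessor in the patched SparkConf answers from the dict rather than from Py4J. A short sketch of that pre-JVM behavior, assuming no SparkContext has been started in the process (so SparkContext._jvm is still None):

    from pyspark.conf import SparkConf

    conf = SparkConf(loadDefaults=False)   # no JVM yet, so _jconf stays None
    conf.setExecutorEnv("VAR1", "value1")  # stored as spark.executorEnv.VAR1
    conf.set("spark.home", "/path")

    assert conf._jconf is None                        # dict-backed store in use
    assert conf.contains("spark.home")
    assert conf.get("nonexistent") is None            # mirrors the Py4J behavior
    assert conf.get("nonexistent", "fallback") == "fallback"
    print(conf.toDebugString())                       # key=value, one per line
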
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a3dd1950a5..1b2e199c39 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -109,7 +109,7 @@ class SparkContext(object):
ValueError:...
"""
self._callsite = first_spark_call() or CallSite(None, None, None)
- SparkContext._ensure_initialized(self, gateway=gateway)
+ SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc, profiler_cls)
@@ -121,7 +121,15 @@ class SparkContext(object):
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc, profiler_cls):
self.environment = environment or {}
- self._conf = conf or SparkConf(_jvm=self._jvm)
+ # java gateway must have been launched at this point.
+ if conf is not None and conf._jconf is not None:
+ # conf has been initialized in the JVM properly, so use conf directly. This represents the
+ # scenario where the JVM was launched before SparkConf was created (e.g. a SparkContext is
+ # created and then stopped, and a new SparkConf and a new SparkContext are created again).
+ self._conf = conf
+ else:
+ self._conf = SparkConf(_jvm=SparkContext._jvm)
+
self._batchSize = batchSize # -1 represents an unlimited batch size
self._unbatched_serializer = serializer
if batchSize == 0:
@@ -232,14 +240,14 @@ class SparkContext(object):
return self._jvm.JavaSparkContext(jconf)
@classmethod
- def _ensure_initialized(cls, instance=None, gateway=None):
+ def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
"""
Checks whether a SparkContext is initialized or not.
Throws error if a SparkContext is already running.
"""
with SparkContext._lock:
if not SparkContext._gateway:
- SparkContext._gateway = gateway or launch_gateway()
+ SparkContext._gateway = gateway or launch_gateway(conf)
SparkContext._jvm = SparkContext._gateway.jvm
if instance:
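
The conf._jconf is not None branch in _do_init covers the restart case mentioned in the comment above: once a context has run, the gateway JVM stays alive, so a SparkConf created afterwards is JVM-backed and can be handed to the new context as-is. A rough sketch of both paths (illustrative only):

    from pyspark import SparkConf, SparkContext

    # First context: no JVM yet, so conf._jconf is None; _do_init builds a fresh
    # JVM-backed SparkConf and the buffered values reach it via spark-submit.
    sc1 = SparkContext(conf=SparkConf().setMaster("local").setAppName("first"))
    sc1.stop()

    # Second context: the gateway launched for the first run is still alive, so
    # this SparkConf gets a real _jconf and _do_init reuses it directly.
    conf2 = SparkConf().setMaster("local").setAppName("second")
    assert conf2._jconf is not None
    sc2 = SparkContext(conf=conf2)
    sc2.stop()
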
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index f76cadcf62..c1cf843d84 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -32,7 +32,12 @@ from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from pyspark.serializers import read_int
-def launch_gateway():
+def launch_gateway(conf=None):
+ """
+ Launch the JVM gateway.
+ :param conf: spark configuration passed to spark-submit
+ :return: the connected JavaGateway object
+ """
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
else:
@@ -41,13 +46,17 @@ def launch_gateway():
# proper classpath and settings from spark-env.sh
on_windows = platform.system() == "Windows"
script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
+ command = [os.path.join(SPARK_HOME, script)]
+ if conf:
+ for k, v in conf.getAll():
+ command += ['--conf', '%s=%s' % (k, v)]
submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
if os.environ.get("SPARK_TESTING"):
submit_args = ' '.join([
"--conf spark.ui.enabled=false",
submit_args
])
- command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
+ command = command + shlex.split(submit_args)
# Start a socket that will be used by PythonGatewayServer to communicate its port to us
callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
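
With a conf attached, the command list is built as script first, then one --conf flag per pair, then whatever PYSPARK_SUBMIT_ARGS contains. A small standalone re-creation of that assembly (build_submit_command and the /opt/spark path are hypothetical, for illustration only):

    import os
    import shlex

    def build_submit_command(spark_home, conf_pairs, submit_args="pyspark-shell"):
        """Mimic the ordering launch_gateway uses when composing the command."""
        script = "./bin/spark-submit.cmd" if os.name == "nt" else "./bin/spark-submit"
        command = [os.path.join(spark_home, script)]
        for k, v in conf_pairs:
            command += ["--conf", "%s=%s" % (k, v)]
        return command + shlex.split(submit_args)

    print(build_submit_command(
        "/opt/spark",
        [("spark.driver.memory", "2g"), ("spark.app.name", "demo")]))
    # ['/opt/spark/./bin/spark-submit', '--conf', 'spark.driver.memory=2g',
    #  '--conf', 'spark.app.name=demo', 'pyspark-shell']
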