diff options
author | Matei Zaharia <matei@databricks.com> | 2013-12-29 14:31:45 -0500 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2013-12-29 14:32:05 -0500 |
commit | 615fb649d66b13371927a051d249433d746c5f19 (patch) | |
tree | 5a3b3487b46517765d31cdc0f2c2f340c714666d /python | |
parent | cd00225db9b90fc845fd1458831bdd9d014d1bb6 (diff) | |
download | spark-615fb649d66b13371927a051d249433d746c5f19.tar.gz spark-615fb649d66b13371927a051d249433d746c5f19.tar.bz2 spark-615fb649d66b13371927a051d249433d746c5f19.zip |
Fix some other Python tests that broke because the JVM is now initialized in a different way
The test in context.py created two different instances of the
SparkContext class by copying "globals", so that some tests can have a
global "sc" object and others can try initializing their own contexts.
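This led to two JVM gateways being created, because SparkConf also looked up the JVM
through pyspark.context.SparkContext, which is a separate class object from the copied one; SparkContext now passes its own JVM to SparkConf.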
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/conf.py | 5 | ||||
-rw-r--r-- | python/pyspark/context.py | 23 | ||||
-rwxr-xr-x | python/run-tests | 1 |
3 files changed, 19 insertions, 10 deletions
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 56e615c287..eb7a6c13fe 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -50,10 +50,11 @@ u'value1' class SparkConf(object): - def __init__(self, loadDefaults=False): + def __init__(self, loadDefaults=True, _jvm=None): from pyspark.context import SparkContext SparkContext._ensure_initialized() - self._jconf = SparkContext._jvm.SparkConf(loadDefaults) + _jvm = _jvm or SparkContext._jvm + self._jconf = _jvm.SparkConf(loadDefaults) def set(self, key, value): self._jconf.set(key, value) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 97c1526afd..9d75c2b6f1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -81,7 +81,8 @@ class SparkContext(object): """ SparkContext._ensure_initialized(self) - self.conf = conf or SparkConf() + self.environment = environment or {} + self.conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer if batchSize == 1: @@ -90,23 +91,30 @@ class SparkContext(object): self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - # Set parameters passed directly on our conf; these operations will be no-ops - # if the parameters were None + # Set parameters passed directly to us on the conf; these operations will be + # no-ops if the parameters were None self.conf.setMaster(master) self.conf.setAppName(appName) self.conf.setSparkHome(sparkHome) - environment = environment or {} - for key, value in environment.iteritems(): - self.conf.setExecutorEnv(key, value) + if environment: + for key, value in environment.iteritems(): + self.conf.setExecutorEnv(key, value) + # Check that we have at least the required parameters if not self.conf.contains("spark.master"): raise Exception("A master URL must be set in your configuration") if not self.conf.contains("spark.appName"): raise Exception("An application name must be set 
in your configuration") + # Read back our properties from the conf in case we loaded some of them from + # the classpath or an external config file self.master = self.conf.get("spark.master") self.appName = self.conf.get("spark.appName") self.sparkHome = self.conf.getOrElse("spark.home", None) + for (k, v) in self.conf.getAll(): + if k.startswith("spark.executorEnv."): + varName = k[len("spark.executorEnv."):] + self.environment[varName] = v # Create the Java SparkContext through Py4J self._jsc = self._jvm.JavaSparkContext(self.conf._jconf) @@ -147,8 +155,7 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeToFile = \ - SparkContext._jvm.PythonRDD.writeToFile + SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: if SparkContext._active_spark_context and SparkContext._active_spark_context != instance: diff --git a/python/run-tests b/python/run-tests index d4dad672d2..a0898b3c21 100755 --- a/python/run-tests +++ b/python/run-tests @@ -35,6 +35,7 @@ function run_test() { run_test "pyspark/rdd.py" run_test "pyspark/context.py" +run_test "pyspark/conf.py" run_test "-m doctest pyspark/broadcast.py" run_test "-m doctest pyspark/accumulators.py" run_test "-m doctest pyspark/serializers.py" |