diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-02-01 11:09:56 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-02-01 11:13:10 -0800 |
commit | 9cc6ff9c4e7eec2d62261fc166ad2ebade148752 (patch) | |
tree | 2aea685655da9a2ed0acf4d7a40f81882e10b1e7 /python/pyspark/context.py | |
parent | 571af31304bd72d310c3b47a8471a4de206aa6fe (diff) | |
download | spark-9cc6ff9c4e7eec2d62261fc166ad2ebade148752.tar.gz spark-9cc6ff9c4e7eec2d62261fc166ad2ebade148752.tar.bz2 spark-9cc6ff9c4e7eec2d62261fc166ad2ebade148752.zip |
Do not launch JavaGateways on workers (SPARK-674).
The problem was that the gateway was being initialized whenever the
pyspark.context module was loaded. The fix uses lazy initialization
that occurs only when SparkContext instances are actually constructed.
I also made the gateway and jvm variables private.
This change results in ~3-4x performance improvement when running the
PySpark unit tests.
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r-- | python/pyspark/context.py | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 783e3dc148..ba6896dda3 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -24,11 +24,10 @@ class SparkContext(object): broadcast variables on that cluster. """ - gateway = launch_gateway() - jvm = gateway.jvm - _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile - _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile - _takePartition = jvm.PythonRDD.takePartition + _gateway = None + _jvm = None + _writeIteratorToPickleFile = None + _takePartition = None _next_accum_id = 0 _active_spark_context = None _lock = Lock() @@ -56,6 +55,13 @@ class SparkContext(object): raise ValueError("Cannot run multiple SparkContexts at once") else: SparkContext._active_spark_context = self + if not SparkContext._gateway: + SparkContext._gateway = launch_gateway() + SparkContext._jvm = SparkContext._gateway.jvm + SparkContext._writeIteratorToPickleFile = \ + SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._takePartition = \ + SparkContext._jvm.PythonRDD.takePartition self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -63,8 +69,8 @@ class SparkContext(object): self.batchSize = batchSize # -1 represents a unlimited batch size # Create the Java SparkContext through Py4J - empty_string_array = self.gateway.new_array(self.jvm.String, 0) - self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, + empty_string_array = self._gateway.new_array(self._jvm.String, 0) + self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) # Create a single Accumulator in Java that we'll send all our updates through; @@ -72,8 +78,8 @@ class SparkContext(object): self._accumulatorServer = accumulators._start_update_server() (host, port) = self._accumulatorServer.server_address self._javaAccumulator = self._jsc.accumulator( - self.jvm.java.util.ArrayList(), - self.jvm.PythonAccumulatorParam(host, port)) + self._jvm.java.util.ArrayList(), + self._jvm.PythonAccumulatorParam(host, port)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. @@ -127,7 +133,8 @@ class SparkContext(object): for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile + jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): |