author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-02-01 11:09:56 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-02-01 11:13:10 -0800
commit     9cc6ff9c4e7eec2d62261fc166ad2ebade148752 (patch)
tree       2aea685655da9a2ed0acf4d7a40f81882e10b1e7 /python/pyspark/context.py
parent     571af31304bd72d310c3b47a8471a4de206aa6fe (diff)
Do not launch JavaGateways on workers (SPARK-674).
The problem was that the gateway was being initialized whenever the pyspark.context module was loaded. The fix uses lazy initialization that occurs only when SparkContext instances are actually constructed. I also made the gateway and jvm variables private. This change results in ~3-4x performance improvement when running the PySpark unit tests.
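
The pattern is ordinary lock-guarded, lazy class-level initialization. A minimal standalone sketch of that pattern (not PySpark's actual code; launch_gateway here is a hypothetical stand-in that just returns a placeholder object so the sketch runs without a JVM):

    from threading import Lock

    def launch_gateway():
        # Stand-in for the real Py4J gateway launcher; returns a placeholder
        # object so this sketch is runnable on its own.
        return object()

    class Context(object):
        _gateway = None   # shared class-level handle, created on demand
        _lock = Lock()    # guards the one-time initialization

        def __init__(self):
            with Context._lock:
                # The expensive launch happens only when the first instance
                # is constructed, never at module import time.
                if Context._gateway is None:
                    Context._gateway = launch_gateway()

    # Importing the module is cheap; constructing a Context triggers the launch.
    ctx = Context()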
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r--  python/pyspark/context.py  27
1 file changed, 17 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 783e3dc148..ba6896dda3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -24,11 +24,10 @@ class SparkContext(object):
     broadcast variables on that cluster.
     """

-    gateway = launch_gateway()
-    jvm = gateway.jvm
-    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
-    _takePartition = jvm.PythonRDD.takePartition
+    _gateway = None
+    _jvm = None
+    _writeIteratorToPickleFile = None
+    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -56,6 +55,13 @@ class SparkContext(object):
                 raise ValueError("Cannot run multiple SparkContexts at once")
             else:
                 SparkContext._active_spark_context = self
+                if not SparkContext._gateway:
+                    SparkContext._gateway = launch_gateway()
+                    SparkContext._jvm = SparkContext._gateway.jvm
+                    SparkContext._writeIteratorToPickleFile = \
+                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                    SparkContext._takePartition = \
+                        SparkContext._jvm.PythonRDD.takePartition
         self.master = master
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -63,8 +69,8 @@ class SparkContext(object):
         self.batchSize = batchSize # -1 represents a unlimited batch size

         # Create the Java SparkContext through Py4J
-        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
-        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
+        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

         # Create a single Accumulator in Java that we'll send all our updates through;
@@ -72,8 +78,8 @@ class SparkContext(object):
         self._accumulatorServer = accumulators._start_update_server()
         (host, port) = self._accumulatorServer.server_address
         self._javaAccumulator = self._jsc.accumulator(
-                self.jvm.java.util.ArrayList(),
-                self.jvm.PythonAccumulatorParam(host, port))
+                self._jvm.java.util.ArrayList(),
+                self._jvm.PythonAccumulatorParam(host, port))

         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
@@ -127,7 +133,8 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
+        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)

     def textFile(self, name, minSplits=None):
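
After this change, importing pyspark.context alone should no longer start a Java gateway; only constructing a SparkContext does. A rough usage sketch of that behavior (the master URL and job name are placeholders, and running it requires a working PySpark installation of this era):

    import pyspark.context                  # no JVM gateway is launched at import time anymore

    from pyspark.context import SparkContext

    # The gateway and the cached _jvm handle are created lazily inside __init__.
    sc = SparkContext("local", "lazy-gateway-example")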