author    Davies Liu <davies@databricks.com>  2015-05-18 12:55:13 -0700
committer Josh Rosen <joshrosen@databricks.com>  2015-05-18 12:55:13 -0700
commit    32fbd297dd651ba3ce4ce52aeb0488233149cdf9 (patch)
tree      c18970db228448c856f8ce9538d708511e9e95fd /python
parent    9dadf019b93038e1e18336ccd06c5eecb4bae32f (diff)
[SPARK-6216] [PYSPARK] check python version of worker with driver
This PR reverts #5404 and instead passes the driver's Python version into the JVM; the worker checks that version before deserializing the closure, so the check still works even when driver and worker run different major versions of Python.

Author: Davies Liu <davies@databricks.com>

Closes #6203 from davies/py_version and squashes the following commits:

b8fb76e [Davies Liu] fix test
6ce5096 [Davies Liu] use string for version
47c6278 [Davies Liu] check python version of worker with driver
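The mechanism is a small handshake: the driver records its own "major.minor" string once, ships it to each worker ahead of the task payload, and the worker compares versions before touching the pickled closure. A minimal sketch of that check (helper names here are illustrative, not Spark's actual API):

    import sys

    def driver_python_version():
        # Only major.minor is compared; differing micro versions are allowed.
        return "%d.%d" % sys.version_info[:2]

    def check_python_version(driver_version):
        # Run on the worker before any unpickling happens.
        worker_version = "%d.%d" % sys.version_info[:2]
        if worker_version != driver_version:
            raise Exception("Python in worker has different version %s than that in "
                            "driver %s, PySpark cannot run with different minor versions"
                            % (worker_version, driver_version))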
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/context.py        |  1 +
-rw-r--r--  python/pyspark/rdd.py            |  4 ++--
-rw-r--r--  python/pyspark/sql/context.py    |  1 +
-rw-r--r--  python/pyspark/sql/functions.py  |  4 ++--
-rw-r--r--  python/pyspark/tests.py          |  6 +++---
-rw-r--r--  python/pyspark/worker.py         | 12 +++++++-----
6 files changed, 16 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 31992795a9..d25ee85523 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -173,6 +173,7 @@ class SparkContext(object):
             self._jvm.PythonAccumulatorParam(host, port))

         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+        self.pythonVer = "%d.%d" % sys.version_info[:2]

         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
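The new pythonVer attribute is simply the driver interpreter's major and minor version joined by a dot. The format is easy to confirm in a REPL (the value naturally depends on which interpreter is running):

    >>> import sys
    >>> sys.version_info[:2]
    (2, 7)
    >>> "%d.%d" % sys.version_info[:2]
    '2.7'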
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 545c5ad20c..70db4bbe4c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2260,7 +2260,7 @@ class RDD(object):
 def _prepare_for_python_RDD(sc, command, obj=None):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
-    pickled_command = ser.dumps((command, sys.version_info[:2]))
+    pickled_command = ser.dumps(command)
     if len(pickled_command) > (1 << 20):  # 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
@@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
                                              bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec,
+                                             self.ctx.pythonExec, self.ctx.pythonVer,
                                              bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()

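Dropping the version tuple from the pickled command is the heart of the revert: the old check lived inside the pickle, so the worker had to deserialize the closure before it could compare versions, and that deserialization is exactly the step that can fail across major versions of Python, masking the real cause. A condensed view of the two orderings (the lambda is a stand-in for the real command tuple):

    import sys
    from pyspark.serializers import CloudPickleSerializer

    ser = CloudPickleSerializer()
    command = (lambda x: x * 2,)  # stand-in for (func, profiler, deserializer, serializer)

    # Before this commit: the version rode inside the pickle, so the worker
    # could only read it after unpickling -- the step most likely to break
    # between major versions.
    old_payload = ser.dumps((command, sys.version_info[:2]))

    # After this commit: only the command is pickled; the version travels
    # separately as a plain string that any Python version can read.
    new_payload = ser.dumps(command)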
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index f6f107ca32..0bde719124 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -157,6 +157,7 @@ class SQLContext(object):
                                             env,
                                             includes,
                                             self._sc.pythonExec,
+                                            self._sc.pythonVer,
                                             bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8d0e766ecd..fbe9bf5b52 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -353,8 +353,8 @@ class UserDefinedFunction(object):
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
-                                                 includes, sc.pythonExec, broadcast_vars,
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf

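Nothing changes for UDF authors: the version string is threaded through automatically at registration time. A typical definition under the 1.4-era API (assuming an active SparkContext) still looks like:

    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    # The pickled lambda now travels together with the driver's pythonVer.
    double = udf(lambda x: x * 2, IntegerType())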
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 09de4d159f..5e023f6c53 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
     def test_with_different_versions_of_python(self):
         rdd = self.sc.parallelize(range(10))
         rdd.count()
-        version = sys.version_info
-        sys.version_info = (2, 0, 0)
+        version = self.sc.pythonVer
+        self.sc.pythonVer = "2.0"
         try:
             with QuietTest(self.sc):
                 self.assertRaises(Py4JJavaError, lambda: rdd.count())
         finally:
-            sys.version_info = version
+            self.sc.pythonVer = version


 class SparkSubmitTests(unittest.TestCase):
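The rewritten test is also less invasive: rather than monkey-patching the interpreter-wide sys.version_info, which other code may read, it fakes only the driver-side string. Given a live SparkContext sc, reproducing the mismatch by hand takes one assignment:

    sc.pythonVer = "2.0"  # no worker actually runs Python 2.0
    sc.parallelize(range(10)).count()  # raises Py4JJavaError carrying the version message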
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fbdaf3a581..93df9002be 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,6 +57,12 @@ def main(infile, outfile):
     if split_index == -1:  # for unit tests
         exit(-1)

+    version = utf8_deserializer.loads(infile)
+    if version != "%d.%d" % sys.version_info[:2]:
+        raise Exception(("Python in worker has different version %s than that in " +
+                         "driver %s, PySpark cannot run with different minor versions") %
+                        ("%d.%d" % sys.version_info[:2], version))
+
     # initialize global state
     shuffle.MemoryBytesSpilled = 0
     shuffle.DiskBytesSpilled = 0
@@ -92,11 +98,7 @@ def main(infile, outfile):
     command = pickleSer._read_with_length(infile)
     if isinstance(command, Broadcast):
         command = pickleSer.loads(command.value)
-    (func, profiler, deserializer, serializer), version = command
-    if version != sys.version_info[:2]:
-        raise Exception(("Python in worker has different version %s than that in " +
-                         "driver %s, PySpark cannot run with different minor versions") %
-                        (sys.version_info[:2], version))
+    func, profiler, deserializer, serializer = command
     init_time = time.time()

     def process():
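On the wire, the version string now precedes the pickled command and is framed the way PySpark frames other strings: a 4-byte big-endian length followed by UTF-8 bytes. A sketch of what utf8_deserializer.loads(infile) does (the real implementation lives in pyspark/serializers.py):

    import struct

    def read_utf8(stream):
        # Read the 4-byte big-endian length prefix, then that many UTF-8 bytes.
        length = struct.unpack("!i", stream.read(4))[0]
        return stream.read(length).decode("utf-8")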